본문 바로가기
공부/실전문제연구단

[실전문제연구단] GREET 코드 뜯어보기 - 3. Edge Discriminator 모델

by 박영귤 2024. 1. 13.

class Edge_Discriminator(nn.Module):
    def __init__(self, nnodes, input_dim, alpha, sparse, hidden_dim=128, temperature=1.0, bias=0.0 + 0.0001):
        """
        그래프 엣지를 판별하는 모델인 Edge Discriminator의 클래스입니다.

        Parameters:
        - nnodes: 노드의 수
        - input_dim: 입력 특성의 차원
        - alpha: Negative sampling에서 사용되는 가중치
        - sparse: 희소 그래프 여부
        - hidden_dim: 은닉 레이어의 차원 (기본값: 128)
        - temperature: Gumbel-Softmax 샘플링의 온도 매개변수 (기본값: 1.0)
        - bias: Gumbel-Softmax 샘플링에서의 바이어스 (기본값: 0.0001)

        Returns:
        - None
        """
        super(Edge_Discriminator, self).__init__()

        # 임베딩 레이어 및 MLP 레이어 정의
        self.embedding_layers = nn.ModuleList()
        self.embedding_layers.append(nn.Linear(input_dim, hidden_dim))
        self.edge_mlp = nn.Linear(hidden_dim * 2, 1)

        # 모델 파라미터 설정
        self.temperature = temperature
        self.bias = bias
        self.nnodes = nnodes
        self.sparse = sparse
        self.alpha = alpha

모델은 위처럼 구현되어있다. 노드를 임베딩해주는 layer, 노드 두 개를 넣어서 edge정보를 계산해주는 edge_mlp레이어가 있다.


    def weight_forward(self, features, edges):
        """
        모델을 통해 엣지의 가중치를 계산합니다.

        Parameters:
        - features: 노드 특성
        - edges: 그래프의 엣지 정보

        Returns:
        - Gumbel-Softmax로 샘플링된 엣지의 가중치 (Low Probability, High Probability)
        """
        embeddings = self.get_node_embedding(features)
        edges_weights_raw = self.get_edge_weight(embeddings, edges)
        weights_lp = self.gumbel_sampling(edges_weights_raw)
        weights_hp = 1 - weights_lp
        # high probability가 heomo, low probability가 hetero
        return weights_lp, weights_hp

이걸 보면 어디에 뭐가 사용되는지 알 수 있다.

 

get_node_embedding으로 노드의 feature vector를 노드임베딩으로 변환해준다. 이 함수에선 embedding_layer가 사용된다.

 

get_dege_weight에는 노드 임베딩과 엣지를 넣는다. edge_mlp를 통과해서 엣지의 가중치가 생성된다. 이때 생성된 edge_weight_raw에는 10556개의 weight(스칼라값)이 들어있다.

 

gumbel_sampling은 수학적으로 뭔가 sampling을 하는 것이다. 굼벨 샘플링은 Generative Models이나 강화 학습(Reinforcement Learning)에 많이 사용된다고 한다. sigmoid를 사용했기 때문에 0~1 사이이다. 이를 통과하면 엣지가 low probability한 것에 대한 가중치를 나타내준다. 여기서 high probability한 것에 대한 가중치를 얻으려면 1에서 빼주면 된다. (low probability는 heterophily, high probability는 homophily라고 생각하면 될 듯 하다.)

굼벨 분포(gumbel distribution)

 

위에서 얻은 weight_lp, weight_hp를 이용해서 homophily한 버전의 인접행렬(adj_hp)과 heterophily한 버전의 인접행렬(adj_lp)를 얻을 수 있다. weight_to_adj라는 함수에서 이 동작이 수행된다. (코드는 github 참고)


Reference

논문 - https://arxiv.org/pdf/2211.14065.pdf

github - https://github.com/yixinliu233/GREET