본문 바로가기
공부/실전문제연구단

[실전문제연구단] GREET 코드 뜯어보기 - 2. GCL 모델

by 박영귤 2024. 1. 12.

코드에 모델이 두 개 있다. 하나는 GCL, 하나는 Edge_Discriminator이다. 이 게시글에서는 GCL모델이 무엇인지 설명할 것이다.

GCL은 Graph Contrastive Learning의 약자로, 비교하는 방식을 학습시키는 모델이다. 간단히 말해서 두 입력의 차이를 계산한다고 할 수 있다.

논문에서 소개한  Dual-channel Representation Learning Module이다.


데이터셋

Nodes: 2708
Edges: 10556


class GCL(nn.Module):
    def __init__(self, nlayers, nlayers_proj, in_dim, emb_dim, proj_dim, dropout, sparse, batch_size):
        super(GCL, self).__init__()

        # 그래프 합성곱 레이어를 사용하여 노드의 임베딩을 생성하는 SGC 모듈 생성
        self.encoder1 = SGC(nlayers, in_dim, emb_dim, dropout, sparse)
        self.encoder2 = SGC(nlayers, in_dim, emb_dim, dropout, sparse)

        # 프로젝션 헤드(Projection Head) 레이어 정의
        if nlayers_proj == 1:
            self.proj_head1 = Sequential(Linear(emb_dim, proj_dim))
            self.proj_head2 = Sequential(Linear(emb_dim, proj_dim))
        elif nlayers_proj == 2:
            self.proj_head1 = Sequential(Linear(emb_dim, proj_dim), ReLU(inplace=True), Linear(proj_dim, proj_dim))
            self.proj_head2 = Sequential(Linear(emb_dim, proj_dim), ReLU(inplace=True), Linear(proj_dim, proj_dim))

        # 배치 크기 저장
        self.batch_size = batch_size​

모델의 구성요소로는 SGC(Simplifying Graph Convolutional Networks) 두 개와, 프로젝션 헤드 레이어 2개가 있다. 프로젝션 헤드는 다차원을 저차원으로 축소시켜주는 선형레이어이다.

이 모델이 하는 역할은 두 그래프로 노드임베딩을 구하고, 그 임베딩 사이의 차이를 이용해 loss를 계산하는 방식이다.


    def forward(self, x1, a1, x2, a2):
        # 두 그래프에 대해 그래프 합성곱과 프로젝션을 수행하고 Contrastive Learning 손실 계산하여 반환
        emb1 = self.encoder1(x1, a1)
        emb2 = self.encoder2(x2, a2)
        # 다차원을 저차원으로 축소시킴. (type : Tensor)
        proj1 = self.proj_head1(emb1)
        proj2 = self.proj_head2(emb2)
        # Contrastive Learning 손실 계산
        loss = self.batch_nce_loss(proj1, proj2)
        return loss

forward함수는 다음과 같다. 즉 두 그래프를 입력시켜 노드임베딩을 반환한다. (emb1.shape : 2708*128) 그 후, 프로젝션 레이어를 통해 크기를 축소시키고 loss를 계산해준다. 이 때 contrastive learning을 사용한다.

그렇다면 Contrastive Loss란 무엇일까?

  • Contrastive Loss: Contrastive refers to the fact that these losses are computed contrasting two or more data points representations. This name is often used for Pairwise Ranking Loss, but I’ve never seen using it in a setup with triplets.

두 개 이상의 데이터 표현을 대조하여 계산하는 loss이다.

https://na1-4an.tistory.com/106 

팀원인 나은이의 게시물인데, 이 곳에 더 자세히 나와있다.


 

이 모델에서 사용하는 loss는  InfoNCE Loss(Information Noise Contrastive Estimation)이다.

이를 자세히 설명하려면 글이 길어질 것 같으니 간단히만 작성하겠다.

Word2Vec에서는 한 word를 vector로 바꾸는 방법에 대해 소개한다. 이 때, 문맥을 반영해서 vector로 변형시킨다. 여기서 context(문맥)은 옆 단어를 말한다. 예를 들어 "빠른 주황색 여우가 점프를 한다"라는 문장에서 "주황색"이라는 단어의 문맥은 "빠른", "여우가"이다. 이 때, 관련이 없는 문맥까지 적용해버릴 수도 있다. 이를 방지하기 위해 나온 것이 Negative Sampling이다. 고급진(?) 메커니즘을 통해 Negative한지 Positive한지 파악한 후에, positive한 확률에서 negative한 확률을 뺀 것을 loss로 설정하였다.

noise contrastive estimation

이것을 NCE(noise contrastive estimation)라고 부른다.

너무 어렵네요 ㅠ 이해한 것이 틀렸을 수 있습니다.


Reference

논문 - https://arxiv.org/pdf/2211.14065.pdf

github - https://github.com/yixinliu233/GREET    

블로그 - https://junia3.github.io/blog/InfoNCEpaper