ELECTRA 이해하기



ELECTRA란?

ELECTRA는 이전 포스팅 “Introduce to BERT” 에서 다룬 BERT를 기반으로 한 모델입니다.
“Efficiently Learning an Encoder that Classifies Token Replacements Accurately” 라는 의미를 가지며, 말 그대로 대체된 Token을 정확하게 분류하는 Task로 만들어진 모델입니다.

gan_img_1

특히 ELECTRA는 기존 GAN(Generative Adversarial Network)의 Generator와 Discriminator의 아이디어를 활용했다고 합니다.

GAN은 Noise Vector를 입력받는 Generator를 통해 가짜 이미지를 생성하고, Discriminator는 진짜 이미지 데이터와 Generator를 통해 만들어진 가짜 이미지를 분별하는 Network입니다. 이 둘은 서로를 적대적으로 (Adversarial) 학습합니다.

  • Generator : Generator를 통해 만들어진 가짜 이미지로 Discriminator가 속게 학습함
  • Discriminator가 : Generator를 통해 만들어진 가짜 이미지를 잘 분류할 수 있게 학습함.

하지만 ELECTRA는 Generator와 Discriminator가 Adversarial하게 학습하는 것이 아닌, likehood하게 학습한다는 차이점이 있습니다. 밑에서 ELECTRA에 대한 설명을 시작하겠습니다.


Replaced Token Detection (RTD)

electra_img_2

위 사진은 Generator와 Discriminator로 이뤄진 ELECTRA의 Architecture이며, 사진을 참고하여 pre-trained과정을 간단히 설명하겠습니다.

[동작과정]

“the”, “chef”, “cooked”, “the”, “meal”

위 5개의 단어들은 원본 문장을 tokenize 된 token입니다. 이 단어들은 Generator에 들어가 MLM Task를 통해 학습힙니다.
때문에 위 5개의 token에 랜덤하게 [MASK]를 취해줍니다.

“[MASK]”, “chef”, “[MASK]”, “the”, “meal”

위와 같이 MLM Task의 input으로 적합한 token들이 나왔습니다.
이 token들은 Generator의 input으로 들어가 MLM Task를 수행하여 각각의 [MASK]를 예측합니다.

“the”, “chef”, “ate”, “the”, “meal”

“the”라는 token은 올바르게 “the”로 예측을 했지만, “cooked”는 “ate”로 틀리게 예측을 했으며, 이 token들은 곧장 Discriminator의 input으로 들어가 학습하게됩니다.

Discriminator는 input으로 들어온 token들 중에서 원본 데이터가 대체되었는지 아닌지를 (original/replaced) 분류하며 학습을 진행합니다.
이 Task를 “Replaced Token Detection (RTD)” 라 합니다.

Mechanism

1) Generator

입력 벡터 $\boldsymbol x = [x_1, …, x_n]$ 중에서 랜덤하게 Masking할 단어들을 고릅니다.
(Masking된 단어들 : $\boldsymbol m = [m_1, …, m_k]$)

$\boldsymbol x^{masked}$ = REPLACE($\boldsymbol x, \boldsymbol m$, [MASK])

그 후 위 식을 통해 마스킹을 완료합니다.

electra_img_3

그리고 입력 벡터 x = $[x_1, …, x_n]$가 문맥적 의미를 가질 수 있게 $h$(x) = $[h_1, …, h_n]$ 문맥을 갖는 벡터로 변화해준 후, 위 식 Generator에 input으로 넣은 후 예측합니다. (t번째의 $\ x$에 대한 확률 값을 반환함)

2) Discriminator

처음 입력 벡터 x = $[x_1, …, x_n]$에, Masking되고 Generator거친 결과 값을 대치(Replace) 합니다.

$\boldsymbol x^{corrupt}$ = REPLACE($\boldsymbol x, \boldsymbol m, \boldsymbol {\hat x}$)

여기서 $\boldsymbol {\hat x}$ 벡터 값은 Generator에서 예측한 결과 ($P_G(x_t$ㅣ$\boldsymbol x)$) 입니다.

electra_img_4

그 후 Discriminator에 $\boldsymbol x$ 벡터를 넣으면, $h_D(\boldsymbol x)_t$를 거치고 가중치와 곱해진 후, Sigmoid layer를 통해 0~1사이의 값을 갖습니다.
0~1사이의 값은 이 단어가 Replace된 단어인지 아닌지를 예측하며, 학습해 나갑니다.

Loss Function

이제 위 Generator와 Discriminator가 어떤식으로 학습을 진행하지는 이야기해보겠습니다.

electra_img_7

그 후, 위 식($L_{MLM}$)을 통해 Generator의 loss 값을 구하고,
아래 식($L_{Disc}$)으로 Discriminator의 loss값을 구합니다.

electra_img_8

이후 두 식을 더해서 minimization 하는 방향을 Backpropagation을 진행합니다.


Performance

electra_img_1 위 사진을 보시면, GLUE 성적은 BERT모델보다 5.0 이상 좋은 성적을 높다는 것을 알 수 있습니다.
또한 타 모델에 비해 적은 연산량(FLOPS)으로 좋은 performance를 보여주는 모델임을 알 수 있습니다.
이상으로 ELECTRA모델에 대한 소개를 마치겠습니다.

Reference
  • https://arxiv.org/abs/2003.10555
  • https://christinakouridi.blog/2019/07/09/vanilla-gan-numpy/
  • https://www.onemathematicalcat.org/MathJaxDocumentation/MathJaxKorean/TeXSyntax_ko.html