본문 바로가기

Paper Review

[Paper Review] Differential Transformer (ICLR, 2025)

TL;DR: Attention noise를 줄이는 새로운 Differential Attention 방법론 제안

 

1. Introduction

Transformer는 불필요한 Context에 Over-attend한다

위 실험은 document 내에서 정답을 가져오는 Task에서 attention score를 추출한 결과이다. Tansformer는 answer 외의 부분에 너무 많은 attention score를 할당하고 있으며, 저자들은 이를 attention noise 라고 부른다.

 

이를 Differential attention이라는 연산을 통해 해결한다. 전자공학에서 두 signal의 간섭 (차이)을 이용하여 noise-canceling하는 것에 착안하여, 분포의 차이를 이용하여 attention noise를 줄이는 것이 핵심 아이디어

 

Differential attention에 대한 논문에서의 설명은 다음과 같다.

we partition the query and key vectors into two groups and compute two separate softmax attention maps. Then the result of subtracting these two maps is regarded as attention scores.

 

 

기존 Attention과의 차이점을 나타낸 pseudo-code는 다음과 같다.

 

2. Differential Transformer

 

2.1 Differential Attention

Given  

  1. hidden dimension of the model $d_{model}$
  2. Input $X \in \mathbb{R}^{N \times d_{model}}$
  3. $Q_1, Q_2, K_1, K_2 \in \mathbb{R}^{N \times d}, V \in \mathbb{R}^{N \times 2d}$

기존에 우리가 알고 있던 Attention과 비교하면 이렇게 다르다.

 

Scaled Dot-product Attention (in Transformer)

$Q = W^Q, K=XW^K, V=XW^V$

$\text{Attn}(X) = \text{softmax}(\frac{QK^T}{\sqrt{d}})V$

 

 

Differential Attention

$[Q_1;Q_2] = W^Q, [K_1,K_2]=XW^K, V=XW^V$

$\text{DiffAttn}(X) = (\text{softmax}(\frac{Q_1K_1^T}{\sqrt{d}})-\lambda\text{softmax}(\frac{Q_2K_2^T}{\sqrt{d}}))V$

 

 

이 때 $\lambda$ 또한 learnable scalar이며, 다음과 같이 re-parameterize 된다.

$\lambda = \exp(\lambda_{q_1} \cdot \lambda_{k_1}) - \exp(\lambda_{q_2} \cdot \lambda_{k_2}) + \lambda_{\text{init}}$

$\lambda_{q_1},\lambda_{k_1}, \lambda_{q_2}, \lambda_{k_2} \in \mathbb{R}^d$ 모두 learnable vector이다

 

$\lambda_{\text{init}}$은 실험을 통해 $0.8-0.6 \times \exp(-0.3 \cdot (l-1))$이 가장 잘 동작하는 것을 발견했다

(이거 어케 발견했지)

근데 그냥 0.8로 해도 꽤나 괜찮은 Initalize라고 함

 

Multi-Head Differential Attention

(h)가 attention head 개수라고 할 때,

projection matrices $W_i^Q,W_i^K,W_i^V,i \in [1,h]$는 각각 다른 matrix를 사용하고

The scalar $\lambda$는 헤드에서는 공유한다

 

Head output은 다음과 같이 normalize & project 된다

$\text{head} = \text{DiffAttn}(X;W_i^Q,W_i^K,W_i^V,\lambda)$

$\overline{\text{head}i} = (1-\lambda{\text{init}})\cdot \text{LN}(\text{head}_i)$

$\text{MultiHead}(X) = \text{Concat}(\overline{\text{head}_1}, \cdots, \overline{\text{head}_h} )W^O$

 

이 때 $W^O \in \mathbb{R}^{d_{model} \times d_{model}}$는 learnable projection matrix이고,

Concat은 channel dimension에서 이루어진다.

$h = d_{model} / 2d$로 정했다

 

Headwise Normalization

$\text{LN}(\cdot)$은 $\text{GroupNorm}(\cdot)$로 표현하였으며, 이는 각 head에 독립적으로 적용된다는 뜻이다.

Differential attention이 기존보다 더 sparser pattern을 갖고, 더 diverse statistical information을 갖기 때문에 이렇게 적용했다.

실제로는 RMSNorm을 사용하였다.

 

2.2 Overall Architecture

설명의 편의성을 위해 Decoder-only architecture를 가정했다

초기 $X^0 =[x_1, \cdots,x_N]\in \mathbb{R}^{N \times d_{model}}$가 임베딩되었을 때,

이 $X^0$는 다음과 같이 $X^L$로 contextualize된다

$Y^l = \text{MultiHead(LN}(X^l)) + X^l$

$X^{l+1} = \text{SwiGLU(LN}(Y^l)) + Y^l$

이 때 $\text{SwiGLU}(X) = (\text{swish}(XW^G) \odot XW_1)W_2$이고,

$W^G, W_1 \in \mathbb{R}^{d_{model} \times \frac{8}{3}d_{model}}, W_2 \in \mathbb{R}^{\frac{8}{3}d_{model} \times d_{model}}$인 learable matrices이다.

 

3. Experiments

이렇게 구축한 Differential Transformer를 large language model화 해서 실험하였다.

각 sub-section마다 실험 세팅이 다른데, 모두 소개하기 제한되기 때문에 자세한 실험 세팅이 궁금한 분들은 논문 참고

 

3.1 Language Modeling Evaluation

  • 다양한 downstream tasks에서 기존 well-trained Transformer-based model들과 성능을 비교한 부분이다
  • 실험은 아키텍처 구성 (hidden size, # of layers, head dimension)부터 train 파라미터 (sequence length, batch size, train tokens, optimizer 세팅) 까지 fair comparison을 목표로 설정하였다.
  • LM Eval Harness benchmark로 실험한 결과, favorable performance를 기록했다!

 

3.2 Scalability Compared with Transformer

DIFF Transformer의 Scalability는 어떨까??

 

  • LLaMA와 같은 세팅을 공유하며, model size와 training tokens를 점점 늘렸을 때의 결과를 측정하였다.
  • 더 적은 파라미터와 더 적은 tokens수로도 충분한 scalability를 기록하였다

 

3.3 Long-Context Evaluation

  • 3B-size 모델의 Context length를 64K까지 확장했을 때의 결과를 측정한 부분
  • 기존 Transformer와 유사한 양상을 보이지만, increasing context에서 더욱 효과적인 결과를 보여주었다. (더 낮은 NLL을 기록하였다)

 

3.4 Key Information Retrieval

Large context에 임베딩된 critical information을 잘 추출할 수 있는지에 대한 부분

Intro에서 주장한 Attention noise의 영향이 가장 잘 드러나는 핵심 실험이기도 하다

Retrieve from 4K Context Length

  • 4K 길이의 Context에 N개의 needles (concise sentence)가 있고 R개의 query를 물어봤을 때 잘 Retrieve하는 지 실험하였다
  • DIFF가 훨씬 잘했음!!

Retrieve from 64K Context Length

N=8, R=1로 고정한 상태로 Context Length를 64K까지 늘렸을 땐 어떨까?

  • Transformer는 같은 query라도 depth에 따라서 다른 성능을 보이지만, DIFF는 그에비해 stable한 결과를 보여주었다.

Attention Score Analysis

  • Normalized attention score를 통해 attention noise를 측정한 결과이다
  • DIFF는 answer에 더 많은 attention score를 할당하고 있고, 특히 Attention noise를 줄이는 데에 매우 효과적이다.

 

3.5 In-Context Learning

In-context learning은 language model의 fundamental capability 중 하나이다.

이 section에서는 이 능력을 1) many-shot classification과 2) robustness 측면에서 검증하였다

Many-Shot In-Context Learning

  • demonstration sample을 1-shot부터 64K lengths에 도달할 때 까지 제공해보았다.
  • DIFF는 5.2%부터 21.6%까지 substantial improvement를 보여주었다.

Robustness of In-Context Learning

  • Transformer는 demonstration의 order permutation에 취약하다고 알려져 있다. (Reranking model을 쓰는 이유기도 하다)
  • 그에 반해 DIFF는 순서에 대해 굉장히 Robust한 결과를 보여주고 있다

 

3.6 Contextual Hallucination Evaluation

Text summarization & Question answering에서 hallucination이 Transformer보다 적게 일어난다를 주장하는 부분

Input context는 항상 참인 경우만 고려하였다

  • GPT-4o에게 free of hallucination인지를 binary judgement로 물어보는 evaluation protocol을 적용했다
    • GPT-4o도 LLM인데, halluciation을 평가할 수 있나..?라고 생각했지만, 선행 연구에서 이래도 괜찮다고 한다

 

3.7 Activation Outliers Analysis

Layer의 특정 부분이 과도하게 activate되어 결과를 내는 현상을 Activation Outlier라고 부른다

이는 model quantization에 큰 어려움을 야기하고, LLM이 풀어야 할 골칫거리 중 하나이다

Diff는 Transformer보다 이게 훨씬 적게 일어남을 보이는 부분

 

이를 바탕으로 실제로 Quantize했을 때에도 Transformer보다 성능 감소율이 훨씬 적었다

  • Attention noise를 줄임으로써 자연스레 얻게 되는 이점 중 하나인 것 같다

 

3.8 Ablation Study

  • GroupNorm 자체로도 성능 향상을 일으킨다
  • 복잡한 $\lambda_{\text{init}} = 0.8-0.6 \times \exp(-0.3 \cdot (l-1))$ 이 가장 높은 성능을 기록하긴 했으나, 그냥 $\lambda_{\text{init}} = 0.8$로 해도 꽤나 robust하다

 

References