概要
SB IntuitionsのSeng Pei Liew、李凌寒、高瀬翔です。
弊社では日本語能力に主眼を置いた大規模言語モデルの構築に取り組んでおり、パラメータの大規模化のための施策として、学習済みの70Bパラメータ*1をMixture-of-Experts(MoE)モデルに拡張し、事前学習を行ったモデルをSarashina2-8x70Bとして公開しました。
言語モデルの100Bパラメータ以上までの大規模化はまだ弊社でも試行錯誤の段階にあり、引き続き大規模なモデルの学習を行っていますが、 本記事では先日公開したSarashina2-8x70Bの性能やMoE、学習済モデルのMoEモデルへの拡張について紹介したいと思います。
Mixture-of-Expertsについて
Sarashina2-8x70BモデルではTransformerにMoE層を組み込んだアーキテクチャを採用しています。 通常のTransformerアーキテクチャとの対比を下記の図に示します。
図の左側、通常のTransformerアーキテクチャでは各層はSelf-attention層とFFN層の組み合わせで構成されます。 以降、これをMoEと対比する目的でDenseと呼びます。 これに対し、MoEを組み込んだTransformerは、図の右側にあるように、FFN層がMoE層に置換されています。
下記の図にMoE層の内部を詳しく示しました。
図にあるように、MoE層は複数のFFN層と、Routerからなります。 MoEはMixture-of-Expertsという名前に表されているように、複数のエキスパートの中から入力に適したエキスパートを(複数)選択し、そのエキスパート(もしくはエキスパート群)で出力を行うモデルです。 この図では、FFN層がエキスパート、Routerが入力に適したエキスパートを選択する部位となっており、FFN層が4つ、すなわち、4つのエキスパートから入力に適したエキスパートを用いるモデルとなります。 Routerは入力に対して各エキスパートがどの程度適切かの確率を付与し、確率上位k個のエキスパートを用いて出力を行います。 この図では、FFN1、FFN2、FFN3、FFN4に、それぞれ0.5、0.4、0.07、0.03という確率が付与されています。 この図ではk=2、すなわち、4つのエキスパートから入力に適した上位2つのエキスパートを選択して出力を行うことを想定しており、FFN1、FFN2は出力の計算に参加しますが、FFN3、FFN4は除外されます。
このように、全エキスパートの中から各入力に対して適したエキスパートを選択することで、全パラメータ数分の表現力を保持しつつ、各入力に対する計算を軽量化したモデルがMoEです。 なお、各入力に対する計算に利用しているパラメータをアクティブパラメータと呼びます。 Sarashina2-8x70Bはモデル全体のパラメータ数は460Bですが、各MoE層では8つのエキスパート中、入力に適した上位2つのエキスパートを計算に用いるモデルとなっており、アクティブパラメータ数は130Bとなっています*2。
また、MoEを用いたモデルでは各MoE層のエキスパート数×ベースのTransformerアーキテクチャのパラメータ数という記法が採用されることが多いです。 すなわち、8x70Bという表記は、基本的には70BパラメータのTransformerアーキテクチャであり、各FFN層がMoE層に置き換わっている、さらに、各MoE層は8つのエキスパートを持つ、ということを示します。
モデル構築の戦略
MoEモデルは各入力に対して全パラメータを計算に用いる訳ではないという性質上、全パラメータ数が同じDenseモデルよりは高速に学習が可能ですが、前述のとおり、Sarashina2-8x70Bはアクティブパラメータ数も130Bと非常に大規模なモデルとなっており、学習にも多大な計算資源を必要とします。 学習を少しでも効率化するために、我々は学習済みの70Bモデルを活用することを考えました。
図1に示したように、Sarashina2-8x70BのアーキテクチャはMoE層以外はDenseモデルと共通であり、また、図2に示したように、MoE層内部もDenseモデルと同様のFFN層を複数個持っているだけです。 このことから、MoE層以外は学習済みDenseモデルのパラメータをそのまま転用し、MoE層についてはDenseモデルのFFN層のパラメータをエキスパートの個数分複製することで、Denseモデルと同等の性能のMoEモデルを構築することができます。 この状態を初期値として学習を開始することで、MoEモデル構築の効率化を図りました。 この手続きはSparse Upcyclingと呼ばれています。 また、この手続きで学習を行ったため、下記の表に示したように、Sarashina2-8x70Bは次元数や層数など多くのハイパーパラメータがSarashina2-70Bと共通の値となっています。
モデル | Sarashina2-70B | Sarashina2-8x70B |
---|---|---|
層数 | 80 | 80 |
各層の次元数 | 8192 | 8192 |
FFN層の中間層の次元数 | 28672 | 28672 |
Self-attentionのヘッド数 | 64 | 64 |
各MoE層内のエキスパート数*3 | - | 8 |
各入力に対して使用するエキスパート数 | - | 2 |
総パラメータ数 | 70B | 460B |
アクティブパラメータ数 | 70B | 130B |
学習率に関しては8x1B、8x7Bのモデルで実験を行い、
- 大きな値を設定した場合、学習済みモデルのパラメータから乖離してしまうからか、初期に損失値が増大してしまうこと*4
- 小さな値の場合には初期の損失値の増大は防げる一方、長期間学習した際の性能向上は少ないこと
を確認しました。 特に、学習率に大きな値を設定した場合の、損失値の増大はパラメータ数が大きくなるほど顕著でした。 今回の学習では、Upcyclingしたパラメータを毀損しないように元のモデルであるSarashina2-70Bの学習終了時の学習率を採用しました。 具体的には、最大学習率 学習終了時の学習率はそれぞれ
- Sarashina2-70B:
- Sarashina2-8x70B:
という値になっています。
性能
学習を終えたモデルの性能を評価するために、以前公開したSarashina2-70B の性能で用いた日本語QAタスクのスコアを掲載します。 Sarashina2-8x70Bモデルのアクティブパラメータは130Bであるため、100B以上のパラメータを持つ、フルスクラッチで学習された日本語モデルとの性能比較も掲載します。
モデル | AI王 | abc | JEMHopQA | NIILC | JComQA | JSQuAD |
---|---|---|---|---|---|---|
Sarashina2-70B | 89.20 | 91.20 | 82.91 | 68.50 | 94.55 | 88.45 |
Sarashina2-8x70B | 90.90 | 90.67 | 82.05 | 65.35 | 95.26 | 91.04 |
stockmark-100B | 68.10 | 53.06 | 39.32 | 46.46 | 40.04 | 51.80 |
Plamo-100B | 83.00 | 85.57 | 64.96 | 70.08 | 93.48 | 89.73 |
LLM-JP-3-172B-β1 | 79.70 | 80.68 | 53.85 | 60.63 | 86.24 | 84.38 |
LLM-JP-3-172B-β2 | 83.60 | 84.84 | 57.26 | 68.50 | 91.42 | 87.35 |
まず、Sarashina2-8x70Bは公開されている 100B以上の日本語モデルと比較しても、全体的に良好なスコアを出しており、特に AI王、JComQA、JSQuAD においては最高のスコアを出しています。
一方で、Sparse UpcyclingのベースとなったSarashina2-70Bと比べた場合、一部のタスク(AI王、JComQA、JSQuAD)に関してはスコアが向上する一方で、他のタスク(abc、NIILC)ではスコアが低下しています。 この理由に関しては、Sarashina2-8x70Bの学習に使用していたコーパスが、Sarashina2-70Bの学習と同じものであったため、新しい知識が得られる余地が限られていたことが考えられます。 または、日本語QAタスクにおいては、70Bの時点で多くのタスクで8割、9割超えのスコアを出しており、その中でさらにスコアを向上させることが難しいという側面も考えられます。
おわりに
本記事ではMoEを用いた大規模言語モデルであるSarashina2-8x70Bの構築戦略や性能を紹介しました。 SB Intuitionsでは引き続き日本語大規模言語モデルの開発に取り組んでいきます。
*1:BはBillionの略で10億ですので、70Bパラメータは700億パラメータとなります。本記事では表記としてBを採用します。
*2:本記事でのパラメータ数計算において、Embedding層および、確率分布を計算する重み行列(LMHead層)は除外しています。
*3:Sarashina2-70BはMoE層を持っていないため数を表記していませんが、FFN層をエキスパートとみなし、エキスパートを1つのみ持つという解釈も可能です。
*4:warmupで学習率を上げるのに従ってゆっくりと上がっていく現象で、いわゆるLoss spikeとは異なると考えられます。