このページの目的
CAMEは学習に使用するoptimizerです。
技術的にもコスパ・タイパ面で非常に優れている反面、活用する上で無視できない問題点が存在します。本来の実力の半分も発揮できないほどです。
それではもったいないので専用ページを作成しました。
特徴
optimizerの一種
- 基本情報はオプティマイザーページを参照
原理
update計算式
現在のstepにおける更新量updateの値は、テンソルの要素毎に下記の計算が行われる。
update =(AdamW, AdaFactorにおけるupdate) * res_approrx
res_approrxは信頼度係数で、多くの場合0~2.0程度の値になります。
res_approx計算(近似式)
res_approx ≒ 1/√res
*計算の過程で行列分解を仲介するので厳密な数式は書けないため、この数式は大まかな傾向を表したものになります。
res計算
res = (現在のstepのupdate - updateの指数移動平均)^2 = (update - exp_avg)^2
updateはgradから計算された瞬間的な増加量(暫定値)です。
下図はそれを表したグラフです。
過去のstepの指数移動平均exp_avgとupdateが乖離するほど不信頼度resが大きくなる仕組みです。
update瞬時値を、大きく3パターンに分けて考えてみましょう。
- 条件1:updateがexp_avgと同符号かつ、同等
- 条件2:updateがexp_avgと同符号だが、とても大きい
- 条件3:updateがexp_avgと異符号
各条件におけるres = (update - exp_avg)^2の計算結果は、
条件1のときres小、条件2と3のときres大になります。
異符号の場合は、問答無用で大きくなるのがポイント
実際には、resはそのstepで即座には反応せず、指数移動平均として後のstepに少しずつ作用します。
そのため、resの逆数は擬似的な学習率倍率として機能します。
結果として、
- updateとexp_avgが違うほどresが下がり、ブレーキがかかる
- updateが振動が多い場合、長いstep期間でresが低下し、強いブレーキがかかる
それによって下記が抑制される。
- 学習の主流と違う要素。ほどほどに抑制される
- 画像のJPEGノイズ
- ランダムな背景要素
なお、ポジティブな情報、新しい情報(タグ等)については、
resの長い指数移動平均の立ち上がり期間内に取得できるため、学習は抑制されにくい。
補足
- 異符号計算をすることが目的なのであれば、sign関数などを用いれば簡単なのですが、
(update-exp_avg)^2という数式はより計算回数を減らしたシンプルな設計になっています。- 逆に言えば、少し計算コストを増やしても良いのであれば、まだまだ改良の余地があります
resの効果を別目線で...
- これはあくまでもイメージ図ですが、weight平面があるときに、
理想的なweight変化の軌道があって、なるべくその軌道に乗せたいという状況があると仮定します。 - その軌道から離れた所から、あるstepでその軌道を飛び越えてオーバーシュートし始める際に、
resがなければ、強いオーバーシュートが発生するところを、
resがあれば、grad反転→res急増→update量が急激に低下するため、ベクトルの向きが鋭く変化し、オーバーシュートを防ぎやすくなります。
- その特性こそが、「CAMEの収束は早い」と評価される要因の一つです。
- 露悪的に言えば、オーバーシュートがないことが多様性を阻んだり、局所解の脱出を遅らせるというデメリットもあります。
設定値
基本的にはデフォルト設定(無設定)で問題ありません。
それ以外のパラメータは個人で判断するべきですので、おすすめ値などは紹介しません。
ただし、betasだけは初期値だと問題があるため、簡単に指針を示します。
詳しい内容を論文を読んで下さい。*1
betas
指数移動平均(EMA)の慣性力の強さを決めるパラメータ。
betasのデフォルト値は、
betas=[0.9, 0.999, 0.9999]
beta[2]=0.9999が大きすぎる
betasのデフォルト値は、当wiki読者の使用条件との相性が悪いです。
20000step以上学習する予定がなければ、下記と同等以下にした方がいいでしょう。
betas=[0.9, 0.99, 0.99]
■理由
- beta[2]=0.9999と高いということは、
1stepあたりのresの影響度は1/10000しかありません。
resの値が正確に反映されるまでに最悪10000step以上かかります*2。
それ以下のユーザーは効果を体感しにくくなります。 - その特性は、モデルをちょっと変更したいユーザーとのニーズと合致していません。
どちらかというと、長いstepをかけてモデルを0から学習するユーザー向けです。しかし、長step実行する場合でも、そもそもスケーリングやresは行列分解を使用していて精度が低いので、それを高いbetaで信頼性を上げても、最大限機能を発揮できるのかという疑問があります。 - 具体的には、
学習中に、修正不能なほどモデルを破壊するオーバーシュートがゆっくりかかる感じがする場合は、下げた方がいいでしょう。 - 同様に、
beta[1]=0.999もそこそこ高いので0.99以下にした方がいいでしょう。
beta[0]=0.9は修正不要です
clip_threshold
デフォルトの1.0で問題ありません。
- レイヤー内のRMSを計算したときに、clip_thresholdよりも大きいパラメータの更新を弱める。
スパイク状の異常値が発生したときそれを弱めるのが目的であるが、
その分、新しい情報を削ってしまうというデメリットがある。 - 実用上は、スパイク状の異常値が発生しても、下流の計算で平均化処理がなされるためあまり重要ではない。
- clip_thresholdを変える副作用として、更新量そのものも変化するため、Lrを変えるのと同じ効果が発生する。
eps
eps=(1e-30, 1e-16)
- 計算中における各パラメータの最小値です。
CAMEの場合は、ゼロ除算対策に使用されます。また値が小さいと、その分updateの最小値も下がります。
dtype(fp16,fp8等)との相性問題があり、学習が不安定な場合、または安定し過ぎて進まない場合は少し調整する必要があります。
- eps[0]=1e-30の値については、
fp16使用時のアンダーフローや、過剰なフィルタリングを感じる等の理由で学習が進まない場合は、1e-30~1e-12程度に増加してもいいです。ただし、小さいgradに対しても更新が発生するため、不安定になります。
小さ過ぎるepsはゼロ除算を抑制できず、過剰なupdateが発生する可能性があります。
- デフォルト値で問題ありません。
- fp16やfp8のような低精度型の場合は更新が0にならないよう慎重に調整するべきです。
特に、学習中に唐突な振動が発生した場合は、このパラメータの過大または過小を疑った方がいいでしょう。
weight_decay
過適合対策用のパラメータ。
p.data(weight)が過剰に大きくなって、そのパラメータが支配的になりすぎないようにする。
p.dataが大きいほどその値を0に向けたupdateが発生する
AdamW等でも存在している独立した数式。
デフォルトの0で問題ありません。
weightの変化が大きいのが気に入らない場合は多少上げる
ただし、計算時間が増えます。
ちなみに、weight decay以外の設計改善やモデル構造の改善によって、2026年以降ではあまり利用価値はありません。過適合対策が目的ならば、正則化画像を増やす方がよほど効果があります。
問題点・バグ
既知の問題点とその対策法を紹介します。
注意:
コーディング上のバグ
主に最適化不足の部分です。
無駄なres計算がある
- else条件ではresを使用しないのに、res計算をしているのは無駄である。
res = (update - exp_avg)**2 + group["eps"][1] if factored: ... .. . update = res_approx.mul_(exp_avg) else: update = exp_avg.clone()
- 対策
その部分をif文内に移動するだけでOKif factored: res = (update - exp_avg)**2 + group["eps"][1] ... .. . update = res_approx.mul_(exp_avg) else: update = exp_avg.clone()
- これで計算速度、VRAM消費が改善します。
ノーリスクなので必ず実施しましょう!
CAME設計の問題点
update, exp_avgのオーダーが考慮できていない
exp_avgの値が大きいと、それだけでresの値が大きくなり、強いブレーキが発生する。
本来積極的かつ健全に学習が進むレイヤーであっても、更新が限りなく0に近づき、ウェイトが小さく収束してしまう。
具体的には、weightが初期値あるいは1e-10オーダーよりも大きくなるのが困難になります。*一般的にもっと高い値で収束するのが普通です
- ウェイトが小さくなるということは、モデルの収束のしやすさという点で擬似的なweight_decayとしても作用するという点ではメリットだが、
よりパラメータをアクティブに活用したいユーザーにとっては、AdamWよりもこじんまりとした結果になる - 特に、行列分解によるパラメータロック(後述)と相性が最悪であり、局所解の脱出ができなくなる
明らかに目標と異なるのに、学習が進まなくなる原因の1つ
- 対策:exp_avg(1次モーメント)で除算して正規化します。
res = 解離量L2 / 変化量慣性
= (update - exp_avg)^2 / (mean(exp_avg) + eps)
# 変更前
res = (update - exp_avg)**2 + group["eps"][1]
# 変更後:
res = update.sub(exp_avg).pow_(2)
denom = exp_avg.pow(2).mean().add_(group["eps"][1])
res.div_(denom)
計算時間増加を最小限にするため、exp_avgの統計値だけに計算を限定しています。
更新量が増えるため、少し不安定になるかもしれません。
resの初期値が小さすぎる
resの初期値は0です。
学習初期に、res_approx≒∞になり強烈なupdateが発生し、loss=NaNの危険性が極めて高いです。
対策として1に初期化しましょう。
#変更前
state["exp_avg_res_row"] = torch.zeros(grad_shape[:-1]).type_as(grad)
state["exp_avg_res_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).type_as(grad)
# 変更後
state["exp_avg_res_row"] = torch.ones(grad_shape[:-1]).type_as(grad)
state["exp_avg_res_col"] = torch.ones(grad_shape[:-2] + grad_shape[-1:]).type_as(grad)
ただし、発散覚悟でとにかく最短で収束させたいという場合は、初期値でも良いと思います。
resの上限がなく、ブレーキが強すぎる。
- resに強烈なスパイクが発生したときに、
resの指数移動平均が高くなりすぎて復帰できなくなる。 - 結果として、該当するレイヤーが固着する。
固着している間にも、他のパラメータの更新はどんどん進む。
他のパラメータは、固着しているパラメータの存在を前提に更新が進むため、バランスが壊れる。 - 明らかに目標と異なるのに、学習が進まなくなる原因の1つ
- 対策
res計算の最後に下記を追加して下さい。
resの最大値を制限します。res = res.clamp_(max=2e+2) - res=200の上限設定は、res_approx=1/√200=0.07
すなわち、体感上、Lrを0.07倍したときよりも小さくさせないと言う意味です。これにより完全なロックを防ぐことができます。
resの下限がなく、アクセルが強すぎる
- resの下限値がないため、res_approxが過大になり、updateが過大になる
- これについては、結果としてresが大きくなる方向に働くので放置しても良い。
あまりbeta[2]を上げすぎると、アクセルからブレーキへの切り替えが遅く破綻するリスクがあるので注意しておいたほうがいいだろう - 対策(非推奨)
どうしても気になる場合のみ,res_approx計算の最後に上限値をかけるといい。
計算時間がほんの少し伸びますが(0.1sec/step未満)、そもそも恩恵を感じにくい変更ですres_approx.clamp_(max=4.0) # この行を追加 update = res_approx.mul_(exp_avg) - res_approxをmax=4.0で制限するという操作は、そのパラメータの部分だけ、Lrを4倍にしたときと同じラインで制限するという意味です。
元々Lrを小さく設定している人にとっては、制限しなくてもいいでしょう。そのままにしたほうが局所解脱出に貢献できるので
行列分解による精度面の限界
- これは不具合ではなく仕様です。
- CAMEは、スケーリング*3とresに行列分解を使用しているため、計算コスト削減と引き換えに精度が大きく低下します。
- 巨大な行列を、最初の行、最初の列だけで近似しているというのが問題です。
(問題というか、そこまで正確である必要はないという証左でもあることの裏返しです) - 問題となる区間はここ。
その部分は行列分解をしないならば、たったこれだけのシンプルな数式です。
exp_avg_res_row = state["exp_avg_res_row"] exp_avg_res_col = state["exp_avg_res_col"] exp_avg_res_row.mul_(group["betas"][2]).add_( res.mean(dim=-1), alpha=1.0 - group["betas"][2] ) exp_avg_res_col.mul_(group["betas"][2]).add_( res.mean(dim=-2), alpha=1.0 - group["betas"][2] ) res_approx = self._approx_sq_grad(exp_avg_res_row, exp_avg_res_col)exp_avg_res = state["exp_avg_res"] exp_avg_res.mul_(group["betas"][2]).add_( res, alpha=1.0 - group["betas"][2] ) res_approx = exp_avg_res.rsqrt()
res_approx単体、update*res_approxによって更新量が減るという機能だけで見れば、長時間学習すればいずれ理想的な場所に収束するはず。
しかし、そこに行列分解が存在することで、各テンソル要素にとっては不都合な、大きなブレーキが発生します。
特に顕在化するのは、レイヤー内部のパラメータが複雑な分布な場合。例えば、
- TextEncoderのように、特定のトークンが、レイヤー内に局所的に存在している場合。
- 最新モデルのDiT(AnimaのTransofomer等)のように、パラメータがレイヤーを超えてより複雑に絡み合っている場合
はっきり言えば、SDXLではあまり問題ないが、Animaでは顕在化する
その結果、
- 目に見えて、プロンプト追従性の低下や画像が溶けたようになる
- 体感上の更新量が大きく抑えられるため、パラメータの局所解脱出ができなくなり、学びたい要素が学べなくなる。
理論上、AdamW8bitよりも精度面で劣る可能性があります。
対策としては、
- VRAMや学習時間の微増を許容できる人は
Gemini等のChatAIに依頼して、行列分解を除去してもらいましょう。
数行ちょちょいと直すだけの、AIにとっては簡単な作業です。
明らかにCAMEを尊重しない新規のコードを開発する行為になるので、当wikiでの紹介は控えます。 - resとスケーリング両方の行列分解を解除する必要はありません。
- 計算順序として最後の砦にあたるresの行列分解を解除するだけで十分効果はあります。
- それにスケーリング区間においては、eps[0]をそこそこ上げておけばロックを抑制でき、gradでレイヤー行列内のコントラストは十分稼げるので、割となんとかなる。

