LoRA/学習方法/SDXL

Last-modified: 2024-03-24 (日) 21:29:52

SDXLモデルを使用した学習の方法です。
LoRA/学習方法ページの1.5系モデル向けLoRA学習ができる事を前提に説明します。

十分な検証報告が上がっていないため、下記の情報は不正確です。
試しに学習したいと思ったときの、メモ程度にご覧ください。

前提条件

はじめに。

  • SDXLでの学習は、v1.5系よりも大きな計算負荷がかかります。
    sd-scriptsではVRAM消費削減のために、text encoderを学習対象から外すなど、
    SDXL本来の学習品質に満たないかもしれません。
    学習するパラメータが多いため、計算時間も当然増えます。
  • 2023年末頃までは学習のスペック要求が高かったりエロが出にくかったりで学習メリットが少なかったのですが
    モデルやUIの改良が進み、普通に学習させることもエロを出すこともできるようになりました。
    学習してみる価値はありますぜ!

動作環境

VRAM量10GB以上のグラボ推奨。(設定次第で8GBに収める事も可)

  • あくまでスレ民の動作報告があったというだけ。
    12GBでも足りないという報告もある。
  • 設定、素材の枚数によって消費が変わる為一概に言えないが12GBあれば解像度1024で学習が可能。

sd-scriptsの導入

  • sd-scriptsのver.0.6.6以上をインストールしよう。

アップグレードする場合の留意点

  • PyTorch等もv2.0系に更新することになるため、不具合に注意。
    「A matching Triton is not available, some optimizations will not be enabled.
    Error caught was: No module named 'triton'
    というエラーが出るかもしれないが、実害はない限り気にしない。Windowsユーザーには関係ない。

学習の実施

v1.5とSDXL学習の違い

v1.5系のLoRA学習ができることを前提に説明します。
sd-scripts公式のReadmeも必ず参照のこと。

公式解説(SDXL):https://github.com/kohya-ss/sd-scripts#sdxl-training

学習モデル

SDXLのモデルを使用します。
VAEを使う場合は、専用のVAEを使います。

学習設定:必須

sd-scripts Ver.によっては解決する可能性があります。

  • train_network.pyの代わりに、sdxl_train_network.pyを使います。
    accelerate launch --num_cpu_threads_per_process 1 sdxl_train_network.py
  • 機能改善のため入れるべきオプション
    • network_train_unet_only
      2層あるtext encoderの学習はsd-scriptsでサポートされていないので、unetのみの学習にする。
      ※VRAMさえあれば、学習自体はできる?要確認。
  • sd-scriptsがSDXL向けにサポートできていないオプション。
    下記のサポート外オプションは外すことになります。
    • weighted_captions
    • clip skip
      そもそも、SDXLでClip skip=2にするメリットがあるのか?という課題はある。

学習設定:推奨

VRAM消費削減のために、入れることになるだろうオプション

  • no_half_vae
  • full_fp16、full_bf16
    • なお、現在full_fp16は推奨されていない。通常はfp32とfp16の混合精度で学習されており
      全てfp16にする事は学習力の低下に繋がる。
    • bitsandbytes 0.41.1以降が必要。bf16は機能に対応したグラボが必要。
  • オプティマイザー
    • adafactorはVRAM消費が少ないので、公式推奨されている。
    • use_8bit_adamなど、軽量のものを使う。
  • cache_latents_to_disk
    学習前にディスクに移動する。
    • accelerate launch(中略) sdxl_train_network.pyコマンドより先に、
      tools/cache_latents.pyを実行する。
    • 筆者もやってみたが、VRAM量削減は体感できず。
      ちゃんと効果はあるのかもしれないが、そもそもモデルが重く確認どころでなかった。
      駄目元でやってみるのが良い程度。
  • (未確認情報)fp8で学習させるとVRAMを節約できるらしい?
    • 多少精度が落ちるとの情報もあるが許容範囲とも言われており、こちらも未検証。加筆求む