機械学習エンジニアの吉田です。本記事では、LayoutLMv3*1というモデルをバクラクで取り扱っている帳票で事前学習を行い、それをファインチューニングして項目推定タスクに取り組んでいる話をご紹介します。
背景
LayerXで提供しているバクラクでは帳票をアップロードするだけで支払金額や支払期日などを自動で読み取り補完してくれるOCR機能があります。このOCR機能には大きく2つの処理があります。
- 帳票に書かれている文字列を認識し検出すること
- 検出された文字列から支払金額や支払期日などの項目を推定すること
2つ目の項目推定において現在はRoBERTa*2というモデルを使っています。RoBERTaでも精度高く推定することができるのですが、複雑なレイアウトの場合に誤って推定してしまうケースがどうしても発生してしまいます。RoBERTaはOCRで検出したテキストだけを使ったモデルであるためこのような複雑なレイアウトの場合に弱いのではないか、という仮説があります。しかし、バクラクにおいてOCRはコアな機能であるためどのような帳票であっても高い精度で読み取れることが期待されます。
このような背景から現在、LayoutLMv3というモデルを検証しています。LayoutLMv3はテキストだけでなく画像やbounding boxを取り込んだモデルであるため複雑なレイアウトの場合でも精度高く推定することができるようになるのではないか、という期待があります。
実は以前にもLayoutLMv3の検証をしたことはありましたが、そのときはRoBERTaの方が精度が高かったためRoBERTaを採用したという経緯があります。 そのあたりの話は以下の記事に詳しく書かれています。
しかし以前はbaseモデルでしか検証されておらず、より性能の高いlargeモデルで検証できていなかったことや、その他にも改善の余地があったため改めて検証を行いました。
LayoutLMv3
まず、LayoutLMv3とはどのようなモデルであるか簡単に説明します。
LayoutLMv3のアーキテクチャは下図のようになっています。 画像とテキストからそれぞれ埋め込みを取得し、連結したものをTransformer Encoderに入力しています。 Transformer Encoderは事前学習済みのRoBERTaで初期化します。
テキストに関してはまず文書画像に対してOCRを適用しテキストとbounding boxを取得します。 LayoutLMv3では単語レベルではなくセグメントレベルのbounding boxを採用しています。 セグメントレベルのbounding boxでは以下の図 (StructuralLM*3から引用) のように同じ意味を表している複数の単語 (背景色が同じ単語) を包含するbounding boxを用います。 論文にはセグメントレベルのbounding boxの作成方法については書かれていませんでしたが、Issuesを見るとOCRで検出した行単位のセグメントを使っているようでした*4
OCRで取得したテキストをトークナイズし、事前学習済みのRoBERTaで初期化されたEmbedding Layerからトークンの埋め込みを取得します。
その後、1D, 2D position embeddingsを加算します。
1D position embeddingsはpaddingを除くトークンに対するシーケンシャルなインデックスの埋め込みです。
2D position embeddingsはbounding boxの4点と幅と高さの合計6個の埋め込みを連結しています。
次に画像に関してです。 画像はまず前処理で224 x 224にリサイズ後、パッチサイズ16 x 16のパッチに分割されます。(画像パッチの総数は14 x 14) それぞれのパッチをLinear Embeddingを通して埋め込みを取得し、先頭にCLSトークンを連結後、最後にposition embeddingsを加算します。
以上の処理はhuggingface/transformersでは LayoutLMv3Model に実装されています*5
事前学習
LayoutLMv3の事前学習はMasked Language Modeling (MLM), Masked Image Modeling (MIM), Word Patch Alignment (WPA)の3つのタスクがあります。
現在公開されているLayoutLMv3のコードはファインチューニングだけで事前学習のコードは公開されていないのでこれらのタスクは自前で実装する必要があります。
Masked Language Modeling (MLM)
BERT*6のMLMと同様に一部のトークンをマスクして元のトークンを予測します。
BERTではマスクするトークンを選択する際に一様なランダムサンプリングを用いますが、LayoutLMv3ではspan maskingでマスクしています。
SpanBERT*7で用いられているspan maskingはまず幾何分布からスパン長をサンプリングします。(LayoutLMv3ではポアソン分布からサンプリング)
次にランダムに選んだ開始点からスパン長分マスキングします。
ランダムサンプリングだとサブワードだけマスクされてしまいタスクが容易にになってしまうのに対してspan maskingだとひとつづきの単語をマスクされることでタスクの難易度が上がり性能向上に繋がるとあります。
MLMに関してはBERTの実装*8が、span maskingに関してはSpanBERTの実装*9が参考になります。
Masked Image Modeling (MIM)
MIMはBEiT*10で提案された画像に関する自己教師あり学習であり、224 x 224にリサイズされた画像をdVAEに通して取得した14 x 14の画像トークンを正解ラベルとします。 一方で画像をパッチサイズ16 x 16のパッチに分割 (画像パッチの総数は14 x 14) し、画像パッチの一部をランダムにマスクして破損した入力をTransformerに与えて画像トークンを復元するように学習します。
LayoutLMv3で使っているdVAEはDiT*11で事前学習されたものを使っています。BEiTのdVAEとの違いは、BEiTではDALL-E*12のencoderを使っているのに対して、DiTではDALL-Eと同様のアーキテクチャで大量の文書画像で学習されている点です。
以下の図は一番左が元画像で続いてDiT、DALL-EのdVAEで再構成した画像となっており、DALL-Eと比較してDiTの方がより鮮明に文書画像が復元できていることが分かります。
しかし、DiTのdVAEは公開されていないので、事前学習済みのencoderを使う場合はMITライセンスのDALL-Eのencoderが候補となります。
マスクする画像パッチを選択する際はblockwise maskingを用いています。 blockwise maskingでは複数のパッチの集合が矩形となるようにマスキングを行います。 BEiTの実装*13を参考にマスクすると以下のようになります。
MIMの実装に関してはBEiTの実装*14が参考になります。 画像パッチのマスキングは正規分布で初期化された学習可能なパラメータを用いてマスク位置の埋め込みと差し替えています。
Word Patch Alignment (WPA)
WPAはテキスト位置の画像パッチがMIMによってマスクされているかどうかを予測します。
MLMとMIMではテキストと画像をそれぞれ単独でしか学習することができないので、テキストと画像のモダリティ間のアライメントを学習させるのが目的です。
テキストがマスクされているかどうかの判定については論文には言及されていませんでしたが、Issuesによると98%の領域がマスクされているかどうかを閾値としているようです*15
検証
事前学習には時間もお金もかかります。極力手戻りが発生しないように不確実性を下げつつ検証を進めていきました。
まずは実装した事前学習にバグが無いか切り分けたかったので、事前学習以外にバグを埋め込まないように以下のようにミニマムな実装としました。
- バクラクデータセットではなく、オープンデータセット IIT CDIP 1.0 dataset*16 を使用
- テキストと画像の前処理は
LayoutLMv3Processor
を使用- 画像のOCRは
LayoutLMv3ImageProcessor
に組み込みのTesseractを使用 - テキストのtokenizerやToken Embeddingは事前学習済みのRoBERTaを使わずにデフォルトの構成を使用
- 画像のOCRは
- 軽量なbaseモデルを使用
この時点で数千ステップ程度は継続して事前学習が回ること、各タスクのlossが減少することで致命的なバグが無いことを確認しました。
次に事前学習にどの程度のコストが必要となるか見積もりました。
論文ではデータセットのサイズ1100万、バッチサイズ2048で50万ステップの事前学習を行ったとありますが、Issues*17には15万ステップでもほぼ同等の精度が出るとのコメントがありました。
一旦15万ステップを目安としてA100(40GB) x 16で見積もったところ、1回の実験で$1800~$7000程度かかりそうということが分かりました。見積もりに大きな幅がありますが、これはGradient Accumulationを利用するかどうかで大きく変動します。 (この時点ではVRAMの効率化ができていなかったため少し過大な見積もりになっていたかとは思います。)
一方で過去のRoBERTaでの事前学習の経験からそこまで大きなデータセットやバッチサイズで学習させなくても下流の項目推定タスクの性能が出るのではないかという仮説はありました。 一度の事前学習でうまくいく保証はないので、まずは1回あたりのコストを下げて試行回数を増やす方向で考えました。
ここ最近はGCPでA100のスポットインスタンスを確保することは困難なので、最初はV100(16GB)のスポットインスタンスで検証することを考えました。
AMP, Gradient Checkpointing, DeepSpeed ZeROを活用することでbaseモデルであれば、GPU1個あたりバッチサイズ24まで積めたので8GPUであれば戦えそうな感じになりました。
V100 x 8のスポットインスタンスであれば、1日中トレーニングしても$200程度で回すことができ、baseモデルであれば15万ステップを2日で学習できる見込みとなりました。
このタイミングで検証用のデータセットからバクラクのデータセットに切り替え、OCRもTesseractからVision APIに切り替えて事前学習を行いました。ファインチューニングして精度を検証してみたところ、RoBERTaの精度までは到達できないまでも数ポイントの差しかなかったためいけそうな感触を得ました。
ここまでで不確実性はある程度潰せたので、largeモデルへの切り替えと、Transformer EncoderとToken Embeddingに事前学習済みのRoBERTa (nlp-waseda/roberta-large-japanese-seq512) *18 を用いて事前学習を行いました。
A100(80GB)を用いて、データセットのサイズ約50万、バッチサイズ150、50万ステップで事前学習を回しており、実はこのブログを書いている時点でもまだ事前学習は終わっていないのですが、途中でファインチューニングして精度を出してみたところすでにいくつかの項目でRoBERTaの精度を上回っています!
まとめ
本記事では、LayoutLMv3をバクラクのデータセットで事前学習を行い、それをファインチューニングして項目推定タスクに取り組んでいる話をご紹介しました。直近の実験ではRoBERTaの精度を上回るなど良い結果も見えてきています。 RoBERTaとLayoutLMv3の推論結果を比較することで、複雑なレイアウトの場合に改善しているか検証してみたいと思います。 また、さらなる改善に向けて以下の実験にも取り組んでいきたいと考えています。
- 画像トークナイザーとしてDALL-EのdVAEを使っているが、DiTのように自前のデータセットでdVAEを学習させる
- MLMはBERTの実装を用いているが、span maskingによるマスキングを行う
- より大量のデータセットを用いた事前学習
- テキストの多言語対応
続報があれば改めてブログにしたいと思います!
最後に
LayerXにはOCR以外にも機械学習で解きたい課題がいっぱいあります!興味を持たれた方は是非カジュアル面談からでもお気軽にどうぞ!
jobs.layerx.co.jp jobs.layerx.co.jp
*1:https://arxiv.org/abs/2204.08387
*2:https://arxiv.org/abs/1907.11692
*3:https://arxiv.org/abs/2105.11210
*4:https://github.com/microsoft/unilm/issues/838
*5:https://github.com/huggingface/transformers/blob/main/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py
*6:https://arxiv.org/abs/1810.04805
*7:https://arxiv.org/abs/1907.10529
*8:https://github.com/huggingface/transformers/blob/main/src/transformers/data/data_collator.py
*9:https://github.com/facebookresearch/SpanBERT/blob/main/pretraining/fairseq/data/masking.py
*10:https://arxiv.org/abs/2106.08254
*11:https://arxiv.org/abs/2203.02378
*12:https://github.com/openai/DALL-E
*13:https://github.com/microsoft/unilm/blob/master/beit/masking_generator.py
*14:https://github.com/microsoft/unilm/blob/master/beit/modeling_pretrain.py
*15:https://github.com/microsoft/unilm/issues/785
*16:https://ir.cs.georgetown.edu/downloads/sigir06cdipcoll_v05-with-authors.pdf
*17:https://github.com/microsoft/unilm/issues/917
*18:https://huggingface.co/nlp-waseda/roberta-large-japanese-seq512