はじめに
こんにちは。バクラク事業部 機械学習チームの機械学習エンジニアの上川(@kamikawa)です。
バクラクではAI-OCRという機能を用いて、請求書や領収書をはじめとする書類にOCRを実行し、書類日付や支払い金額などの項目内容をサジェストすることで、お客様が手入力する手間を省いています。
書類から特定の項目を抽出する方法は、自然言語処理や画像認識、近年はマルチモーダルな手法などたくさんあるのですが、今回は項目抽出のための物体検出モデルを構築するまでの手順について紹介します。
Document Layout Analysisとは
Document Layout Analysisとは、文書のレイアウトを解析するタスク(直訳)のことを指します。具体的には、文書内のさまざまな要素(例えば、テキスト、画像、表、見出し、段落など)を抽出し、それぞれの位置や意味などを明らかにすることを目的としています。
Document Layout Analysisは、非構造化データを構造化データに変換し、有意義に利用するために重要な手法であり、様々なアプローチで効率的な情報抽出を可能にします。
物体検出を利用したDocument Object Detection
Document Layout Analysisはdeep learningの隆盛前から長年研究対象になっているテーマであり、ルールベースから学習ベースのモデルまでたくさんの手法が存在します。
Document Layout Analysisの処理としては大きく以下の3つがあります。
- 文書に書かれている項目領域を特定する
- 特定の項目領域に書かれている文字列を認識する
- 文字列の意味を認識する
この中の1. 文書に書かれている項目領域を特定する処理にフォーカスすると、近年は学習ベースの物体検出やインスタンスセグメンテーションなどの手法が良い性能を示しており、このような物体検出によるDocument Layout AnalysisはDocument Object Detectionタスクと定義されたりします。以下ではDocument Object Detectionについて詳しく紹介します。
物体検出ライブラリ
Document Object Detectionを検証するにあたって、Detectron2とMMdetectionという2つのライブラリを採用しました。
Detectron2
Facebook AI Researchが開発しており、元々MaskRCNNのみのリポジトリだったのですが、いつの間にか包括的な物体検出やインスタンスセグメンテーションのライブラリになっていました。 github.com
リリース当初は、物体検出ライブラリで使いやすいものがあまりなかったので、当時のデファクトスタンダード感があったのですが、最近は開発の頻度が低下しており、あまり使われなくなった印象です。
公式から比較的わかりやすいチュートリアルが公開されており、少しクセはありますが自前のデータセットをCOCO形式で準備さえできれば、すぐに検証を始めることができる形となっています。
MMdetection
MMdetectionは、OpenMMLabのプロジェクトの一つで、PyTorchベースの物体検出ライブラリです。物体検出だけでなく、インスタンスセグメンテーションやパノプティックセグメンテーションにも対応しています。
特徴としてはなんといっても対応しているモデルの数の凄まじさです。
こちらがMMdetectionのmodel zooです。Detectron2だとFasterRCNN系列のモデルだけ実装されていたりするのですが、MMdetectionはSOTAのRTMDetやContrastive Learningのような学習戦略の手法まで幅広く実装されています。 mmdetection.readthedocs.io
開発は熱量高く行われている印象で、Issueでは頻繁に議論が交わされています。 物体検出モデルのアーキテクチャがモジュール単位で分解して実装されているので、backbone、headや学習ロジックを簡単に組み替えてカスタマイズできます。とりあえず精度が高いモデルを試したいだけではなく、色々な手法を試したい方におすすめかなと思います。
MMdetectionもDetectron2と同様に、独自形式のフレームワークなので慣れるまで大変かと思いますが、チュートリアルも公開されているので、すぐに使い始めることができる形となっております。 github.com
Document Object Detectionモデルの構築
データセット
Document Object Detectionモデルの学習には、PubLayNetとDocLayNetという2つのデータセットを利用しました。
PubLayNetは、PubMed Central Open Access Subset(商業利用コレクション)から作られた学術論文のデータセットです。クラスは、見出し(Heading)、リスト(List)、表(Table)、図(Figure)、段落(Paragraph)の5クラスあります。
アノテーションファイルはCOCO形式のJSONファイルになっていて、images
(画像情報を保持するリスト)、annotations
(各画像に対応するアノテーションのリスト)、categories
(各カテゴリの情報を保持するリスト)で構成されます。
{ "images": [ { "id": 1, "file_name": "PMC1234567_00001.jpg", "height": 1024, "width": 768 } ], "annotations": [ { "id": 1, "image_id": 1, "category_id": 1, "bbox": [100, 200, 300, 400], "segmentation": [[100, 200, 400, 200, 400, 600, 100, 600]], "area": 120000, "iscrowd": 0 } ], "categories": [ { "id": 1, "name": "Title" }, { "id": 2, "name": "List" }, { "id": 3, "name": "Table" }, { "id": 4, "name": "Figure" }, { "id": 5, "name": "Text" } ] }
DocLayNetは、金融、科学、特許、入札、法律文書、マニュアルなど、幅広い分野の文書で構成されるデータセットです。クラスは、タイトル(Title)、テキスト(Text)、リスト(List)、表(Table)、図(Figure)、セクションヘッダー(Section Header)、フッター(Footer)、ヘッダー(Header)、キャプション(Caption)、参照(Reference)、数式(Equation)の11クラスあります。
こちらもアノテーションファイルはCOCO形式のJSONファイルに従っています。以下、具体例です。
{ "images": [ { "id": 1, "width": 1025, "height": 1025, "file_name": "132a855ee8b23533d8ae69af0049c038171a06ddfcac892c3c6d7e6b4091c642.png", "doc_category": "financial_reports", // 高レベルなドキュメントカテゴリ "collection": "ann_reports_00_04_fancy", // サブコレクション名 "doc_name": "NASDAQ_FFIN_2002.pdf", // オリジナルドキュメントファイル名 "page_no": 9, // オリジナルドキュメントのページ番号 "precedence": 0 // 注釈順序、重複注釈がある場合に非ゼロ } ], "annotations": [ { "image_id": 1, // 対応する画像のID "category_id": 1, // クラスラベルのID "bbox": [66.99346405228758, 112.10344760101009, 290.869358251634, 13.66279703282828], // バウンディングボックスの座標 [x, y, width, height] "area": 3963.6758, // バウンディングボックスの面積 "iscrowd": 0 // 群集注釈(0 = 単一オブジェクト、1 = 群集) } ], "categories": [ { "id": 1, "name": "Title" }, { "id": 2, "name": "Text" }, { "id": 3, "name": "List" }, { "id": 4, "name": "Table" }, { "id": 5, "name": "Figure" }, { "id": 6, "name": "Section Header" }, { "id": 7, "name": "Footer" }, { "id": 8, "name": "Header" }, { "id": 9, "name": "Caption" }, { "id": 10, "name": "Reference" }, { "id": 11, "name": "Equation" } ] }
DocLayNetはCOCO形式のアノテーションのカスタムフィールドとして、metadata
, cells
というフィールドを持ちます。
{ "metadata": { "page_hash": "132a855ee8b23533d8ae69af0049c038171a06ddfcac892c3c6d7e6b4091c642", // 一意の識別子、ファイル名と同じ "original_filename": "NASDAQ_FFIN_2002.pdf", // オリジナルドキュメントのファイル名 "page_no": 9, // オリジナルドキュメントのページ番号 "num_pages": 28, // オリジナルドキュメントの総ページ数 "original_width": 612, // オリジナルドキュメントの幅(ピクセル、72 ppi) "original_height": 792, // オリジナルドキュメントの高さ(ピクセル、72 ppi) "coco_width": 1025, // PNGおよびCOCO形式での幅(ピクセル) "coco_height": 1025, // PNGおよびCOCO形式での高さ(ピクセル) "collection": "ann_reports_00_04_fancy", // サブコレクション名 "doc_category": "financial_reports" // 高レベルなドキュメントカテゴリ }, "cells": [ // デジタルPDFデータのすべてのテキストセル { "bbox": [66.99346405228758, 112.10344760101009, 290.869358251634, 13.66279703282828], // バウンディングボックス } ] }
doc_category
などを活用することで、文書の分類タスクを解いたり、文書種別ごとにモデル構築や評価などを行うことが可能となっています。
物体検出ライブラリは色々ありますが、基本的にはCOCO形式のデータセットであれば、画像が置かれたディレクトリとアノテーションのJSONファイルを指定するだけで学習することができます。 データセットを自前で構築し、物体検出モデルを学習する際は、これらのCOCO形式のフォーマットにならってアノテーションを作成すると、すぐに検証を始めることができます。
学習&推論
Detectron2とMMdetectionの両方でDocument Object Detection用モデルの学習と推論を行いました。実装は、それぞれの公式チュートリアルに沿った形となります。
Detectron2の場合
学習を進めていくにあたって、物体検出モデルの設定を読み込む必要があります。
Detectron2のリポジトリのconfigs以下に設定ファイルがあるので、使いたいモデルの設定ファイルを読み込みます。設定ファイルはyaml形式で書かれています。
今回は、faster_rcnn_R_50_FPN_3x
というモデルをベースにモデルを構築します。
from detectron2.config import get_cfg cfg = get_cfg() cfg.merge_from_file("detectron2/configs/COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml") # 学習済みの重みを初期値とする場合は、このようにuriを指定すると初期化時にダウンロードして、ロードしてくれる cfg.MODEL.WEIGHTS = "detectron2://COCO-Detection/faster_rcnn_R_50_FPN_3x/137849458/model_final_280758.pkl" cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128 cfg.MODEL.ROI_HEADS.NUM_CLASSES = 5 # PubLayNetのクラス数
次に、データセットをDetectron2で読み込みます。Detectron2ではPubLayNetを使って学習してみます。register_coco_instances
という関数にデータセットへのPATHを指定するだけでDetectron2にCOCO形式のデータセットを登録することができます。
from detectron2.data import MetadataCatalog from detectron2.data.datasets import register_coco_instances publaynet_path = "~/datasets/publaynet" register_coco_instances( "publaynet_train", {}, f"{publaynet_path}/train.json", f"{publaynet_path}/train" ) register_coco_instances( "publaynet_val", {}, f"{publaynet_path}/val.json", f"{publaynet_path}/val" ) class_labels = ['text', 'title', 'list', 'table', 'figure'] MetadataCatalog.get("publaynet_train").thing_classes = class_labels MetadataCatalog.get("publaynet_val").thing_classes = class_labels # データセット cfg.DATASETS.TRAIN = ("publaynet_train",) # ここにカンマ区切りで別のデータセットを追加することも cfg.DATASETS.TEST = ("publaynet_val",) cfg.DATALOADER.NUM_WORKERS = 8
学習時の設定を追加し、DefaultTrainer
というクラスを使って学習を回します。
import os from detectron2.engine import DefaultTrainer cfg.SOLVER.IMS_PER_BATCH = 2 cfg.SOLVER.BASE_LR = 0.00025 cfg.SOLVER.MAX_ITER = 30000 cfg.TEST.DETECTIONS_PER_IMAGE = 100 cfg.INPUT.MIN_SIZE_TRAIN = (580, 612, 644, 676, 708, 740) os.makedirs(cfg.OUTPUT_DIR, exist_ok=True) trainer = DefaultTrainer(cfg) trainer.resume_or_load(resume=False) trainer.train()
そして、PubLayNetで学習済みのDocument Object Detection用モデルで推論してみます。Detectron2には推論用のDefaultPredictor
と可視化用のVisualizer
というクラスがあるのでそれを利用します。
from detectron2.engine import DefaultPredictor from detectron2.utils.visualizer import ColorMode from detectron2.utils.visualizer import Visualizer cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "model_final.pth") # path to the model we just trained cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7 # set a custom testing threshold predictor = DefaultPredictor(cfg) im = cv2.imread("datasets/publaynet/test/PMC1481631_00004.jpg") outputs = predictor(im) # format is documented at https://detectron2.readthedocs.io/tutorials/models.html#model-output-format v = Visualizer(im[:, :, ::-1], metadata=MetadataCatalog.get("publaynet_train"), scale=0.5, instance_mode=ColorMode.IMAGE_BW) # remove the colors of unsegmented pixels. This option is only available for segmentation models out = v.draw_instance_predictions(outputs["instances"].to("cpu")) Image.fromarray(out.get_image())
MMdetectionの場合
MMdetectionもDetectron2と同様に、設定ファイルで物体検出モデルの設定を読み込みます。
MMdetectionのリポジトリのconfigs以下に様々なモデルの設定ファイルが置かれてますが、Detectron2とは異なり、Pythonで設定を記述しています。
MMdetectionのチュートリアルでは、モデル、データセット、学習、推論などのすべての設定をPythonファイルで管理しているので、それにならって進めていきます。
今回はyolox_s_8xb8-300e
というモデルを使うので、configs/yolox/yolox_s_8xb8-300e.py
というファイルからconfigs/yolox/yolox_s_8xb8-300e_doclaynet.py
というファイルを作成します。
以下、設定ファイルになりますが、すべての設定をこのファイルに書き込んでおり、分かりづらいので、編集した箇所には# NOTE:
コメントをつけております。
configs/yolox/yolox_s_8xb8-300e_doclaynet.py
_base_ = ["_base_/schedules/schedule_1x.py", "_base_/default_runtime.py", "yolox/yolox_tta.py"] img_scale = (640, 640) # width, height # model settings model = dict( type="YOLOX", data_preprocessor=dict( type="DetDataPreprocessor", pad_size_divisor=32, batch_augments=[dict(type="BatchSyncRandomResize", random_size_range=(480, 800), size_divisor=32, interval=10)], ), backbone=dict( type="CSPDarknet", deepen_factor=0.33, widen_factor=0.5, out_indices=(2, 3, 4), use_depthwise=False, spp_kernal_sizes=(5, 9, 13), norm_cfg=dict(type="BN", momentum=0.03, eps=0.001), act_cfg=dict(type="Swish"), ), neck=dict( type="YOLOXPAFPN", in_channels=[128, 256, 512], out_channels=128, num_csp_blocks=1, use_depthwise=False, upsample_cfg=dict(scale_factor=2, mode="nearest"), norm_cfg=dict(type="BN", momentum=0.03, eps=0.001), act_cfg=dict(type="Swish"), ), bbox_head=dict( type="YOLOXHead", num_classes=11, # NOTE: DocLayNetのクラス数 in_channels=128, feat_channels=128, stacked_convs=2, strides=(8, 16, 32), use_depthwise=False, norm_cfg=dict(type="BN", momentum=0.03, eps=0.001), act_cfg=dict(type="Swish"), loss_cls=dict(type="CrossEntropyLoss", use_sigmoid=True, reduction="sum", loss_weight=1.0), loss_bbox=dict(type="IoULoss", mode="square", eps=1e-16, reduction="sum", loss_weight=5.0), loss_obj=dict(type="CrossEntropyLoss", use_sigmoid=True, reduction="sum", loss_weight=1.0), loss_l1=dict(type="L1Loss", reduction="sum", loss_weight=1.0), ), train_cfg=dict(assigner=dict(type="SimOTAAssigner", center_radius=2.5)), # In order to align the source code, the threshold of the val phase is # 0.01, and the threshold of the test phase is 0.001. test_cfg=dict(score_thr=0.01, nms=dict(type="nms", iou_threshold=0.65)), ) # dataset settings data_root = "/home/hoge/datasets/doclaynet" # NOTE: DocLayNetへのPATH dataset_type = "CocoDataset" # PASCALVOC,CityScapesとかの形式も選択可能 # Example to use different file client # Method 1: simply set the data root and let the file I/O module # automatically infer from prefix (not support LMDB and Memcache yet) # data_root = 's3://openmmlab/datasets/detection/coco/' # Method 2: Use `backend_args`, `file_client_args` in versions before 3.0.0rc6 # backend_args = dict( # backend='petrel', # path_mapping=dict({ # './data/': 's3://openmmlab/datasets/detection/', # 'data/': 's3://openmmlab/datasets/detection/' # })) backend_args = None train_pipeline = [ dict(type="Mosaic", img_scale=img_scale, pad_val=114.0), dict( type="RandomAffine", scaling_ratio_range=(0.1, 2), # img_scale is (width, height) border=(-img_scale[0] // 2, -img_scale[1] // 2), ), dict(type="MixUp", img_scale=img_scale, ratio_range=(0.8, 1.6), pad_val=114.0), dict(type="YOLOXHSVRandomAug"), dict(type="RandomFlip", prob=0.5), # According to the official implementation, multi-scale # training is not considered here but in the # 'mmdet/models/detectors/yolox.py'. # Resize and Pad are for the last 15 epochs when Mosaic, # RandomAffine, and MixUp are closed by YOLOXModeSwitchHook. dict(type="Resize", scale=img_scale, keep_ratio=True), dict( type="Pad", pad_to_square=True, # If the image is three-channel, the pad value needs # to be set separately for each channel. pad_val=dict(img=(114.0, 114.0, 114.0)), ), dict(type="FilterAnnotations", min_gt_bbox_wh=(1, 1), keep_empty=False), dict(type="PackDetInputs"), ] train_dataset = dict( # use MultiImageMixDataset wrapper to support mosaic and mixup type="MultiImageMixDataset", dataset=dict( type=dataset_type, data_root=data_root, ann_file="COCO/train.json", # NOTE: COCO形式のアノテーションJSONファイルのPATH data_prefix=dict(img="PNG/"), # NOTE: 画像ディレクトリのPATH pipeline=[dict(type="LoadImageFromFile", backend_args=backend_args), dict(type="LoadAnnotations", with_bbox=True)], metainfo=dict( classes=( "Caption", "Footnote", "Formula", "List-item", "Page-footer", "Page-header", "Picture", "Section-header", "Table", "Text", "Title", ) # NOTE: DocLayNetのクラス ), filter_cfg=dict(filter_empty_gt=False, min_size=32), backend_args=backend_args, ), pipeline=train_pipeline, ) test_pipeline = [ dict(type="LoadImageFromFile", backend_args=backend_args), dict(type="Resize", scale=img_scale, keep_ratio=True), dict(type="Pad", pad_to_square=True, pad_val=dict(img=(114.0, 114.0, 114.0))), dict(type="LoadAnnotations", with_bbox=True), dict(type="PackDetInputs", meta_keys=("img_id", "img_path", "ori_shape", "img_shape", "scale_factor")), ] train_dataloader = dict( batch_size=8, num_workers=4, persistent_workers=True, sampler=dict(type="DefaultSampler", shuffle=True), dataset=train_dataset ) val_dataloader = dict( batch_size=8, num_workers=4, persistent_workers=True, drop_last=False, sampler=dict(type="DefaultSampler", shuffle=False), dataset=dict( type=dataset_type, data_root=data_root, ann_file="COCO/val.json", # NOTE: COCO形式のアノテーションJSONファイルのPATH data_prefix=dict(img="PNG/"), # NOTE: 画像ディレクトリのPATH test_mode=True, pipeline=test_pipeline, metainfo=dict( classes=( "Caption", "Footnote", "Formula", "List-item", "Page-footer", "Page-header", "Picture", "Section-header", "Table", "Text", "Title", ) # NOTE: DocLayNetのクラス ), backend_args=backend_args, ), ) test_dataloader = val_dataloader # NOTE: COCO形式のアノテーションJSONファイルのPATH val_evaluator = dict(type="CocoMetric", ann_file="datasets/doclaynet/COCO/val.json", metric="bbox", backend_args=backend_args) test_evaluator = val_evaluator # training settings # NOTE: 今回は簡単な検証を行うので、少しオリジナルより短くした max_epochs = 50. # NOTE: 今回は簡単な検証を行うので、少しオリジナルより短くした num_last_epochs = 10 interval = 5 # NOTE: validationのinterval train_cfg = dict(max_epochs=max_epochs, val_interval=interval) # optimizer # default 8 gpu base_lr = 0.001 optim_wrapper = dict( type="OptimWrapper", optimizer=dict(type="SGD", lr=base_lr, momentum=0.9, weight_decay=5e-4, nesterov=True), paramwise_cfg=dict(norm_decay_mult=0.0, bias_decay_mult=0.0), ) # learning rate param_scheduler = [ dict( # use quadratic formula to warm up 5 epochs # and lr is updated by iteration # TODO: fix default scope in get function type="mmdet.QuadraticWarmupLR", by_epoch=True, begin=0, end=5, convert_to_iter_based=True, ), dict( # use cosine lr from 5 to 285 epoch type="CosineAnnealingLR", eta_min=base_lr * 0.05, begin=5, T_max=max_epochs - num_last_epochs, end=max_epochs - num_last_epochs, by_epoch=True, convert_to_iter_based=True, ), dict( # use fixed lr during last 15 epochs type="ConstantLR", by_epoch=True, factor=1, begin=max_epochs - num_last_epochs, end=max_epochs, ), ] default_hooks = dict( checkpoint=dict( interval=interval, max_keep_ckpts=3, # only keep latest 3 checkpoints ) ) custom_hooks = [ dict(type="YOLOXModeSwitchHook", num_last_epochs=num_last_epochs, priority=48), dict(type="SyncNormHook", priority=48), dict(type="EMAHook", ema_type="ExpMomentumEMA", momentum=0.0001, update_buffers=True, priority=49), ] # NOTE: `auto_scale_lr` is for automatically scaling LR, # USER SHOULD NOT CHANGE ITS VALUES. # base_batch_size = (8 GPUs) x (8 samples per GPU) auto_scale_lr = dict(base_batch_size=64)
学習はMMdetectionが提供するスクリプトで行います。シングルノード × 1 GPUからマルチノード × マルチGPUまでのスクリプトがあるので、それぞれの実行例です。
$ git clone git@github.com:open-mmlab/mmdetection.git $ cd mmdetection # single node, 1 gpuの場合 $ python tools/train.py configs/yolox/yolox_s_8xb8-300e_doclaynet.py # single node, multi GPUの場合 $ ./tools/dist_train.sh configs/yolox/yolox_s_8xb8-300e_doclaynet.py [GPUS] # multi node, multi GPUの場合(試してません), 環境変数を指定する: NNODES=ノード数 NODE_RANK=ノードのランク PORT=ポート MASTER_NODE=マスターノードのaddr $ ./tools/dist_train.sh configs/yolox/yolox_s_8xb8-300e_doclaynet.py [GPUS]
そしてMMdetectionにDetInferencer
というクラスがあるので、それを使って推論と可視化を行います。
from mmdet.apis import DetInferencer config = "configs/yolox/yolox_s_8xb8-300e_doclaynet.py" checkpoint = "work_dirs/yolox_s_8xb8-300e_doclaynet/epoch_50.pth" # Set the device to be used for evaluation device = 'cuda:0' # Initialize the DetInferencer inferencer = DetInferencer(config, checkpoint, device) # Perform inference result = inferencer("./test.png", out_dir="./outputs")
最後に
今回はDocument Object Detectionに焦点を当て、データセットの作成方法から物体検出モデルを学習するまでの手順について紹介しました。
OCRに関連する技術は日々進歩を続ける中、その技術でお客様の体験をバクラクにするための仲間がまだまだ必要で、一緒に働いてくれる仲間を大募集しております!
少しでも興味を持ってくださった方!ご応募をお待ちしております!