LayoutLM:深入解析文档图像理解的强大模型
1. 简要介绍
在数字化时代,我们每天都会接触到大量的文档,包括扫描件、表格、收据等。如何让计算机理解这些 包含文本和布局信息的文档,一直是人工智能领域的研究重点。传统的自然语言处理(NLP)模型主要关注文本内容,而忽略了文档的布局和视觉信息,这在处理文档图像时会遇到瓶颈. 为了解决这个问题,微软在2020年6月推出了 LayoutLM 模型.
- 背景历史:
- 在LayoutLM之前,NLP模型主要关注文本输入,而计算机视觉模型主要关注图像输入.
- LayoutLM的出现,首次将 图像、文本和2D位置 信息作为输入,实现了 多模态 信息处理.
- 开发团队: LayoutLM由 Yiheng Xu, Minghao Li, Lei Cui, Shaohan Huang, Furu Wei, 和 Ming Zhou 共同开发.
- 功能:
- LayoutLM旨在 理解文档图像,从而实现 信息提取、表单理解、收据理解和文档分类 等任务.
- 它通过同时建模文本和布局信息之间的交互,从而 显著提高 了文档图像理解的性能.
- LayoutLM 可以从 扫描文档或图像 中提取特定的、重点信息.
2. 架构设计
LayoutLM 的架构基于 BERT (Bidirectional Encoder Representations from Transformers). 它在 BERT 的基础上增加了 两种新的输入嵌入:
- 2D 位置嵌入 (2D Position Embeddings): 用于表示文档中 文本的空间位置. 与传统的只考虑单词顺序的位置嵌入不同,2D 位置嵌入使用每个单词的 边界框坐标 (x0, y0, x1, y1) 来定义其在页面上的位置. 文档的左上角被视为坐标系统的原点 (0, 0). 这些坐标被归一化到 0-1000 的范围内,然后嵌入到模型可以理解的数值表示中.
- 图像嵌入 (Image Embeddings): 用于 整合视觉信息. LayoutLM 将图像分割成与 OCR 文本对应的区域,并利用这些区域的视觉特征生成图像嵌入. 图像嵌入有助于模型理解文档的视觉风格,从而增强文档理解能力.
预训练 (Pre-training):
- LayoutLM 使用 Masked Visual-Language Model (MVLM) 进行预训练. MVLM 是受掩码语言模型启发的技术,但它同时考虑文本和2D位置嵌入作为输入. 模型学习预测被掩码的单词,通过上下文的文本和空间位置信息进行预测.
- LayoutLM 还使用 Multi-label Document Classification (MDC) 进行预训练. 该任务训练 LayoutLM 处理带有多个标签的扫描文档,使其能够从多个领域聚合知识并生成更好的文档级别表示,尽管它不是大型模型预训练的必要条件.
- LayoutLM的预训练使用了 IIT-CDIP Test Collection 1.0 数据集,该数据集包含超过600万份文档和1100万份扫描文档图像.
3. 能处理的文档类型
LayoutLM 擅长处理那些 布局和视觉信息对于理解内容至关重要 的文档. 包括以下类型:
- 表单 (Forms): LayoutLM 在表单理解任务上取得了非常好的效果,能够准确地处理具有特定字段和布局的结构化文档. FUNSD 数据集 通常用于训练和评估 LayoutLM 的表单理解能力.
- 收据 (Receipts): LayoutLM 在收据理解任务中也表现出色. 它可以从收据中提取数据,并利用文本和布局信息. SROIE 数据集 用于微调 LayoutLM 的收据数据.
- 扫描文档 (Scanned documents): LayoutLM 能够有效地处理扫描文档,同时建模文本和布局信息之间的交互.
- 商务文档 (Business Documents): LayoutLM 可应用于各种商务文档,包括:
- 采购订单 (Purchase orders)
- 财务报告 (Financial reports)
- 商业邮件 (Business emails)
- 销售协议 (Sales agreements)
- 供应商合同 (Vendor contracts)
- 信件 (Letters)
- 发票 (Invoices)
- 简历 (Resumes)
- 其他视觉丰富的文档 (Other Visually Rich Documents): LayoutLM 适用于任何视觉丰富的文档,在这些文档中,布局显著增强了语言表示.
4. 使用技巧
- OCR 引擎: 使用 OCR (Optical Character Recognition) 引擎 (例如 Tesseract)从文档图像中提取文本及其对应的边界框.
- 边界框归一化: 在将边界框坐标输入 LayoutLM 之前,将它们归一化到 0-1000 范围. 通过将边界框坐标除以文档图像的原始宽度和高度,然后乘以 1000 进行归一化.
- 特殊标记: LayoutLM 使用特殊标记来处理文本,包括:
- [CLS]: 分类标记,用于序列分类,并且是序列的第一个标记.
- [SEP]: 分隔符标记,用于分隔多个序列.
- [PAD]: 填充标记,用于填充不同长度的序列.
- [MASK]: 掩码标记,用于掩码语言建模.
- [UNK]: 未知标记,用于表示词汇表中未知的单词.
- 选择合适的 Tokenizer: 使用 LayoutLMTokenizer 或 LayoutLMTokenizerFast 进行分词. LayoutLMTokenizerFast 是一个更快的版本,基于 Hugging Face 的 tokenizers 库.
5. 运行环境要求
LayoutLM的运行环境要求主要包括以下几个方面:
- 编程语言和框架:LayoutLM可以使用 PyTorch 或 TensorFlow 框架进行实现和训练。
- PyTorch 是一个开源的机器学习库,常用于实现神经网络和深度学习模型。
- TensorFlow 是另一个流行的开源机器学习库,也用于实现神经网络和深度学习模型。
- Hugging Face Transformers 库:这是使用LayoutLM的核心库,提供了预训练模型、tokenizer 以及其他工具。
- 这个库提供了LayoutLM模型的各种实现,包括用于不同任务的变体,例如LayoutLMModel, LayoutLMForMaskedLM, LayoutLMForSequenceClassification, LayoutLMForTokenClassification 和 LayoutLMForQuestionAnswering。
- OCR引擎:需要一个 OCR (光学字符识别) 引擎 从文档图像中提取文本及其对应的边界框。
- 常用的OCR引擎是 Tesseract。
- OCR引擎将图像中的文本转换为机器可读的文本,并提供位置嵌入所需的坐标。
- 图像处理库:需要图像处理库来处理文档图像,例如 Pillow (PIL)。
- 数据处理库:需要使用数据处理库,例如 NumPy 和 Pandas 进行数据处理。
- 硬件要求: 如果要进行模型的训练,GPU 可以显著加快训练速度。
- Python 环境: 需要 Python 编程环境,并安装所需的库。
- Tokenizer:需要使用 LayoutLMTokenizer 或 LayoutLMTokenizerFast 进行分词。 LayoutLMTokenizerFast 是一个更快的版本,基于Hugging Face的tokenizers库。
- Tokenizer负责将文本分割成模型可以理解的token。
- 数据集: 不同的任务需要不同的数据集。例如,FUNSD数据集用于表单理解,SROIE数据集用于收据理解,RVL-CDIP数据集用于文档图像分类。
总而言之,使用LayoutLM需要一个配置了适当库(如 Transformers, PyTorch 或 Tensorflow,以及OCR引擎)的Python环境,以及一个能够进行数据预处理和模型训练的平台。
6. 代码示例
以下是一个使用 LayoutLM 进行序列分类的 PyTorch 代码示例:
import os
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from sklearn.model_selection import train_test_split
import pytesseract
from PIL import Image, ImageDraw, ImageFont
import torch
from datasets import Dataset, Features, Sequence, ClassLabel, Value, Array2D
from transformers import LayoutLMTokenizer, LayoutLMForSequenceClassification, AdamW
# Load the dataset
# Assuming you have a dataframe named 'df' with columns 'image_path', 'words', 'bbox', 'label'
# The bounding box coordinates should be normalized
# Create a dictionary for label to index mapping
labels = df['label'].unique().tolist()
label2idx = {label: idx for idx, label in enumerate(labels)}
# Load the tokenizer and model
tokenizer = LayoutLMTokenizer.from_pretrained("microsoft/layoutlm-base-uncased")
# Define a function to encode training examples
def encode_training_example(example, max_seq_length=512, pad_token_box=):
words = example['words']
normalized_word_boxes = example['bbox']
assert len(words) == len(normalized_word_boxes)
token_boxes = []
for word, box in zip(words, normalized_word_boxes):
word_tokens = tokenizer.tokenize(word)
token_boxes.extend([box] * len(word_tokens))
special_tokens_count = 2
if len(token_boxes) > max_seq_length - special_tokens_count:
token_boxes = token_boxes[: (max_seq_length - special_tokens_count)]
token_boxes = [] + token_boxes + []
encoding = tokenizer(' '.join(words), padding='max_length', truncation=True)
input_ids = tokenizer(' '.join(words), truncation=True)["input_ids"]
padding_length = max_seq_length - len(input_ids)
token_boxes += [pad_token_box] * padding_length
encoding['bbox'] = token_boxes
encoding['label'] = label2idx[example['label']]
assert len(encoding['input_ids']) == max_seq_length
assert len(encoding['attention_mask']) == max_seq_length
assert len(encoding['token_type_ids']) == max_seq_length
assert len(encoding['bbox']) == max_seq_length
return encoding
# Function to prepare data loaders from dataframe
def training_dataloader_from_df(data_df):
dataset = Dataset.from_pandas(data_df)
features = Features({
'words': Sequence(Value('string')),
'bbox': Sequence(Sequence(Value('int64'))),
'label': Value('string'),
})
encoded_dataset = dataset.map(encode_training_example, features=features, remove_columns=dataset.column_names)
encoded_dataset.set_format(type='torch', columns=['input_ids','bbox', 'attention_mask', 'token_type_ids', 'label'])
dataloader = torch.utils.data.DataLoader(encoded_dataset, batch_size=4, shuffle=True)
return dataloader
# Split train and validation datasets
train_data, valid_data = train_test_split(df, test_size=0.2, random_state=42)
# Create dataloaders
train_dataloader = training_dataloader_from_df(train_data)
valid_dataloader = training_dataloader_from_df(valid_data)
# Define the device to train on
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load the model
model = LayoutLMForSequenceClassification.from_pretrained(
"microsoft/layoutlm-base-uncased", num_labels=len(label2idx)
)
model.to(device);
# Define optimizer
optimizer = AdamW(model.parameters(), lr=4e-5)
# Training loop
num_epochs = 3
for epoch in range(num_epochs):
print("Epoch:", epoch)
training_loss = 0.0
training_correct = 0
model.train()
for batch in tqdm(train_dataloader):
labels = batch["label"].to(device)
outputs = model(
input_ids=batch["input_ids"].to(device), bbox=batch["bbox"].to(device),
attention_mask=batch["attention_mask"].to(device),
token_type_ids=batch["token_type_ids"].to(device), labels=labels
)
loss = outputs.loss
training_loss += loss.item()
predictions = outputs.logits.argmax(-1)
training_correct += (predictions == labels).float().sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
print("Training Loss:", training_loss / batch["input_ids"].shape)
training_accuracy = 100 * training_correct / len(train_data)
print("Training accuracy:", training_accuracy.item())
validation_loss = 0.0
validation_correct = 0
model.eval()
with torch.no_grad():
for batch in tqdm(valid_dataloader):
labels = batch["label"].to(device)
outputs = model(
input_ids=batch["input_ids"].to(device), bbox=batch["bbox"].to(device),
attention_mask=batch["attention_mask"].to(device),
token_type_ids=batch["token_type_ids"].to(device), labels=labels
)
loss = outputs.loss
validation_loss += loss.item()
predictions = outputs.logits.argmax(-1)
validation_correct += (predictions == labels).float().sum()
print("Validation Loss:", validation_loss / batch["input_ids"].shape)
validation_accuracy = 100 * validation_correct / len(valid_data)
print("Validation accuracy:", validation_accuracy.item())
这个示例代码展示了如何使用 LayoutLM 进行文档分类. 其中,需要注意的是,输入数据需要包含文本内容(words
),对应的边界框坐标(bbox
),以及类别标签(label
),且边界框坐标需要归一化到 0-1000 的范围内.
7. 常见问题
- LayoutLM 和 BERT 的区别是什么?
- BERT 主要处理文本信息,而 LayoutLM 同时处理文本、布局和视觉信息.
- LayoutLM 通过 2D位置嵌入 和 图像嵌入 来整合布局和视觉信息,使其能够更好地理解文档图像.
- 如何处理不同大小的文档图像?
- 通过 归一化边界框坐标,使 LayoutLM 能够处理各种大小的文档图像.
- LayoutLM 可以处理中文文档吗?
- LayoutLM 可以处理多语言文档,包括中文,前提是使用合适的 tokenizer 和预训练模型.
- 如何选择合适的预训练模型?
- Hugging Face Transformers 库提供了各种预训练的 LayoutLM 模型。您可以根据自己的任务和数据选择合适的模型.
- 如何提高 LayoutLM 的性能?
- 使用高质量的 OCR 结果.
- 使用与任务相关的 微调数据.
- 调整模型 参数,例如学习率和训练轮数.