deep-learning - 预训练Donut(文档理解变压器)

标签 deep-learning pytorch

我正在尝试复制 Donut 模型 ( https://github.com/clovaai/donut ) 的预训练(从头开始训练)阶段。论文中提到:

“在 在预训练阶段,Donut 通过预测下一个内容来学习如何阅读文本 通过联合调节图像和之前的文本上下文来调整单词。”

我不明白的是,“先前的文本上下文”如何用于预训练过程,而模型的输入只是图像。我应该如何更改微调代码才能预训练模型?

最佳答案

我也有同样的问题,看来预训练其实和微调是一样的。唯一的区别是,在预训练中,您训练模型来预测文档上的所有文本(或者,至少达到最大输出大小)。

由于他们使用MBart作为解码器,因此它支持多种语言。 https://github.com/clovaai/donut/blob/master/donut/model.py#L136

但是,他们只提供了针对英文或中文数据的微调模型。

我想用另一种语言对其进行预训练,我所要做的就是使用它们的 SynthDoG 生成一些图像和 json 文件。 ,它从一个简单的文本文件创建了训练、测试和验证文件夹,我必须更改配置,结果是 10k 图像,如下所示:Image for pretraining 这是该图像的 json:

{"file_name": "image_138.jpg", "ground_truth": "{\"gt_parse\": {\"text_sequence\": \"ă \\\"a testat cu succes o rachetă de crozieră tip Zircon în Marea Albă.4 goluriîn 36 de meciuri a adunat Costache pentru CFR Cluj în actuala edițieCea mai mare bază militară a Chin ei dispune de replici ale Turnului Eiffel și clădiril e prezidențiale din TaiwanOdată cu alegerea lui Donald Tru mp în 2016, contra atacul împotriva mla știnii de la Washington a î nceput.Zi neagră pe piaţa criptomo nedelor.Astrele vă recomandă să v ă gândiți mai mult la ceea ce contează cu adevărat. Dar cred că putem avea un minimum de pretenţii chiar şi atunci când suma pe care o pune m la bătaie când cumpărăm o locuinţă este ceva mai mi că.Ele includ lipsa mirosului sau/şi a gustulu i, durere în gât, febră.Cauza probabilă a incen diului a fost stabilita ca fiind coș de fum deteriorat.Se poate să fi găsit ceva re zonabil ca preţ de cumpărat şi să nu ezitaţi s-o faceţi c-o să v-aducă un confort casnic.Iniţiativa îi aparține Conf.Univ.Dr Victor Costache ,șeful secției d e chirurgie cardiovasculară Sf Constantin.Pepenele roşu nu mai trebuie consumat nici de persoanele sănătoase, dacă apar d ureri abdominale, balonare şi gaze instestinale ori afecţiuni ale sistemului digestiv\\\", spune dr.Oamenii președintelui, \\\"vânați\\\" de USR!Săgetătorii, mai ales în prima parte a zilei, manifestă m ai multă înțelegere pentru ceilalți de la care au mari așteptări, astfel încât, spre final de zi, lucrurile se pot l ămuri cu vorbe bune.Contra își cere toți banii! - E dreptul lui să-și ia banii pe ce-a muncit.Glo riile sau Snowball...Demersul, imposibil de dus până la capăt.Un tânăr în vârstă de 19 ani din Franța a prezentat pe ntru a intra într-un spital, unde avea programare, certificatul sanitar al lui Emma nuel Macron, care a ajuns în urmă cu câteva zile să fie distribu it în social media, relatează revista Le Point, citată de News.ro.Nu numai ca s-au redus in cent imetri (dimensiunea 44 are acum masuratorile a ceea ce odinioara era 46), dar moda insas\"}}"}

所有图像的json都在metadata.jsonl中,不用担心单词是否被分割, donut 在这个阶段正在学习字母的样子。 然后在 donut/dataset 中创建一个文件夹 base 并粘贴 SynthDog 创建的 3 个文件夹。 将 donut/config/train_cord.yaml 复制到 donut/config/base.yamldataset_name_or_paths 更改为您在数据集中创建的新文件夹 dataset_name_or_paths: ["dataset/base"] 并根据您的图像分辨率更改输入大小(必须是 320 的倍数),也基于您的 GPU VRAM,更高的分辨率需要 GPU 上更多的 VRAM,max_length 也会影响内存,在 16GB 卡上,我可以使用此配置进行训练,但速度有点慢:

resume_from_checkpoint_path: null # only used for resume_from_checkpoint option in PL
result_path: "./result"
pretrained_model_name_or_path: "naver-clova-ix/donut-base" # loading a pre-trained model (from moldehub or path)
dataset_name_or_paths: ["dataset/base"] # loading datasets (from moldehub or path)
sort_json_key: False # cord dataset is preprocessed, and publicly available at https://huggingface.co/datasets/naver-clova-ix/cord-v2
train_batch_sizes: [1]
val_batch_sizes: [1]
input_size: [1600, 1920] # when the input resolution differs from the pre-training setting, some weights will be newly initialized (but the model training would be okay)
max_length: 700
align_long_axis: False
num_nodes: 1
seed: 2022
lr: 3e-4
warmup_steps: 300 # 800/8*30/10, 10%
num_training_samples_per_epoch: 800
max_epochs: 30
max_steps: -1
num_workers: 8
val_check_interval: 1.0
check_val_every_n_epoch: 10
gradient_clip_val: 1.0
verbose: True

您可以尝试使用更少的max_epochs,这取决于您生成的图像数量,更少的图像更高的纪元..

运行python train.py --config config/base.yaml --exp_version "base"

经过几天的等待,您将在 donut/result/base 中获得基本 pretrained_model 将 donut/result/base/base 的内容上传到 Huggingface,您可以通过将其设置为来对其进行微调 pretrained_model_name_or_path: "your_huggingface_user/base" 或者只是它的路径(如果您不上传)

希望它有所帮助,这是我对 donut 的体验......如果您有任何其他问题,请询问。

关于deep-learning - 预训练Donut(文档理解变压器),我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/75248458/

相关文章:

tensorflow - 如何在不同的 RNN 单元之间共享权重,这些单元在 Tensorflow 中输入不同的输入?

python - pytorch卷积层中第一个初始化的权重是多少

python - Pytorch的卷积损失从一开始就是0.0

tensorflow - 为什么 'dimension' 在机器学习领域有几个不同的含义?

python - 在 numpy 中获取 3D 张量的所有 2D 对角线

machine-learning - 为什么不训练部分时期呢?

machine-learning - 具有多 GPU 方法的 tensorflow 分布式训练混合

python - PyTorch 什么时候自动转换 Tensor dtype?

nlp - 如何将 One-Hot Encoding 值计算为实值向量?

python - Caffe编译时没有看到hdf5.h