To speed things up, I have been looking into PyTorch's DistributedDataParallel and tried to apply it to the transformers Trainer.
The PyTorch examples for DDP state that this should at least be faster:
DataParallel is single-process, multi-thread, and only works on a single machine, while DistributedDataParallel is multi-process and works for both single- and multi-machine training. DataParallel is usually slower than DistributedDataParallel even on a single machine due to GIL contention across threads, the per-iteration replicated model, and the additional overhead introduced by scattering inputs and gathering outputs.
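As a rough illustration of that difference in plain PyTorch, here is a minimal sketch with a placeholder model; the DDP part assumes the script is started by torchrun (or torch.distributed.launch --use_env), which sets the process-group environment variables and LOCAL_RANK:

import os
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel

model = nn.Linear(128, 128)  # placeholder model

# DataParallel: a single process; the model is replicated to every visible GPU
# and inputs are scattered / outputs gathered on every forward pass.
dp_model = nn.DataParallel(model.cuda())

# DistributedDataParallel: one process per GPU; each process joins the process
# group once and wraps its own replica, so there is no per-iteration replication.
local_rank = int(os.environ["LOCAL_RANK"])
dist.init_process_group(backend="nccl")
torch.cuda.set_device(local_rank)
ddp_model = DistributedDataParallel(model.cuda(local_rank), device_ids=[local_rank])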
My DataParallel trainer looks like this:
import os
from datetime import datetime
import sys
import torch
from transformers import Trainer, TrainingArguments, BertConfig
# path_storage and path_vocab are defined elsewhere in my setup;
# ProteinBertMaskedLMDataset and ProteinBertForMaskedLM are my own classes.
training_args = TrainingArguments(
    output_dir=os.path.join(path_storage, 'results', "mlm"),  # output directory
    num_train_epochs=1,                  # total number of training epochs
    gradient_accumulation_steps=2,       # accumulate gradients over multiple steps
    per_device_train_batch_size=4,       # batch size per device during training
    per_device_eval_batch_size=4,        # batch size for evaluation
    logging_dir=os.path.join(path_storage, 'logs', "mlm"),  # directory for storing logs
    evaluate_during_training=False,
    max_steps=20,
)

mlm_train_dataset = ProteinBertMaskedLMDataset(
    path_vocab, os.path.join(path_storage, "data", "uniparc", "uniparc_train_sorted.h5"),
)

mlm_config = BertConfig(
    vocab_size=mlm_train_dataset.tokenizer.vocab_size,
    max_position_embeddings=mlm_train_dataset.input_size,
)

mlm_model = ProteinBertForMaskedLM(mlm_config)

trainer = Trainer(
    model=mlm_model,                   # the instantiated model to be trained
    args=training_args,                # training arguments, defined above
    train_dataset=mlm_train_dataset,   # training dataset
)

trainer.train()
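From what I can tell, switching this to DistributedDataParallel should mostly be a matter of launching the script with one process per GPU and forwarding the local rank, since the Trainer wraps the model in DistributedDataParallel whenever local_rank is not -1. A sketch of what that could look like (the script name train_mlm.py is a placeholder):

# Launch with one process per GPU; torch.distributed.launch appends
# --local_rank to the script's own arguments:
#   python -m torch.distributed.launch --nproc_per_node=2 train_mlm.py

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int, default=-1)  # filled in by the launcher
cli_args = parser.parse_args()

training_args = TrainingArguments(
    output_dir=os.path.join(path_storage, 'results', "mlm"),
    per_device_train_batch_size=4,
    max_steps=20,
    local_rank=cli_args.local_rank,  # -1 keeps DataParallel; >= 0 switches the Trainer to DDP
)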