For some reason, the collate function gets confused when the datasets are interleaved: the extra columns mean it doesn't know how to merge the examples. The way I fixed it was to keep only the columns I want:
# -- Get data set
keep_col = ['text']  # keep only the column(s) shared by both datasets
print('-- interleaving datasets')
datasets = [load_dataset(path, name, streaming=True, split="train").with_format("torch")
            for path, name in zip(paths, names)]
for dataset in datasets:
    print(f'{dataset.description=}')
dataset = interleave_datasets(datasets, probabilities)
# columns to drop = everything in dataset.column_names that is not in keep_col
remove_columns = [col for col in dataset.column_names if col not in keep_col]
print(f'{dataset=}')
batch = dataset.take(batch_size)
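The remove_columns list computed above still has to be applied; a minimal sketch, assuming a datasets version where IterableDataset.remove_columns is available (otherwise pass remove_columns to .map, as the full code below does):

# drop the columns that only exist in one of the interleaved sources, so every
# example ends up with the same schema (just 'text')
dataset = dataset.remove_columns(remove_columns)
print(f'{next(iter(dataset))=}')  # should only contain the 'text' field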
And if you know which text field you want (say "text", since it's common):
def collate_tokenize(data):
    print(f'{data[0]=}')
    text_batch = [element["text"] for element in data]
    tokenized = tokenizer(text_batch, padding='longest', truncation=True, return_tensors='pt')
    return tokenized

# note: the loader wraps the raw interleaved dataset; the collate_fn does the tokenizing
data_loader = DataLoader(dataset, shuffle=False, batch_size=8, num_workers=0,
                         drop_last=False, collate_fn=collate_tokenize)
batch = next(iter(data_loader))
print(f'{batch=}')
print('Done!\a')
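A note on the design choice (my illustration, not from the original code): padding='longest' in the collate_fn is dynamic padding, so each batch is only padded to its own longest sequence, whereas the padding='max_length' used in the map() path of the full code always pads to 128 tokens:

# dynamic padding: second dim equals the longest sequence in this batch
toks = tokenizer(["a short line", "a considerably longer line of text here"],
                 padding='longest', return_tensors='pt')
print(toks['input_ids'].shape)
# fixed padding: second dim is always max_length
toks = tokenizer(["a short line"], padding='max_length', max_length=128,
                 truncation=True, return_tensors='pt')
print(toks['input_ids'].shape)  # torch.Size([1, 128])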
Full code:
def test_interleaved_data_set_2_data_loader():
    """ https://colab.research.google.com/drive/1QWDhA6Q64qijXYnwIGn63Aq9Eg5qt8tQ#scrollTo=Wjyy6QYimvIm """
    # -- Get probe network
    from datasets import load_dataset, interleave_datasets
    import torch
    from transformers import GPT2Tokenizer, GPT2LMHeadModel
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a pad token
    probe_network = GPT2LMHeadModel.from_pretrained("gpt2")
    device = torch.device(f"cuda:{0}" if torch.cuda.is_available() else "cpu")
    probe_network = probe_network.to(device)

    paths, names = ['c4', 'wikitext'], ['en', 'wikitext-103-v1']
    probabilities = [1.0 / len(paths)] * len(paths)
    batch_size = 512

    # -- Get data set
    keep_col = ['text']  # keep only the column(s) shared by both datasets
    print('-- interleaving datasets')
    datasets = [load_dataset(path, name, streaming=True, split="train").with_format("torch")
                for path, name in zip(paths, names)]
    for dataset in datasets:
        print(f'{dataset.description=}')
    dataset = interleave_datasets(datasets, probabilities)
    # columns to drop = everything in dataset.column_names that is not in keep_col
    remove_columns = [col for col in dataset.column_names if col not in keep_col]
    print(f'{dataset=}')
    batch = dataset.take(batch_size)

    # - Prepare functions to tokenize batch
    def preprocess(examples):
        return tokenizer(examples["text"], padding="max_length", max_length=128,
                         truncation=True, return_tensors="pt")
    def map_batch(batch):  # renamed from `map` so the builtin is not shadowed
        return batch.map(preprocess, batched=True, remove_columns=remove_columns)
    tokenized_batch = map_batch(batch)
    print(f'{next(iter(tokenized_batch))=}')

    # -- Get data loader
    from torch.utils.data import DataLoader
    # alternative: tokenize on the fly with a custom collate_fn over the raw dataset
    # def collate_tokenize(data):
    #     print(f'{data[0]=}')
    #     text_batch = [element["text"] for element in data]
    #     tokenized = tokenizer(text_batch, padding='longest', truncation=True, return_tensors='pt')
    #     return tokenized
    # data_loader = DataLoader(dataset, shuffle=False, batch_size=8, num_workers=0, drop_last=False, collate_fn=collate_tokenize)
    data_loader = DataLoader(tokenized_batch, shuffle=False, batch_size=8, num_workers=0, drop_last=False)
    # num_batches = len(list(data_loader))
    batch = next(iter(data_loader))
    print(f'{batch=}')
    print('Done!\a')
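To sanity check the pipeline end to end, a minimal sketch (my addition, not part of the original test) that could be appended inside the function to push one batch through the probe network; since the 'text' column survives the map, only the tensor fields are moved to the device:

    # keep only the tensor fields; the collated 'text' column is a list of strings
    tensor_batch = {k: v.to(device) for k, v in batch.items() if isinstance(v, torch.Tensor)}
    out = probe_network(input_ids=tensor_batch['input_ids'],
                        attention_mask=tensor_batch['attention_mask'],
                        labels=tensor_batch['input_ids'])
    print(f'{out.loss=}')  # causal LM loss on one interleaved batch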