使用OrderedDict
构造vocab
时会按照从大到小的排序来构造token,因此借用这个特点:
- 构造一个列表,专门保存
vocab
中每一个单词的频数,此时列表的下标位置与vocab中下标位置是一一对应的
- 借助
vocab
将文本->token
- 直接用列表的
token
对应到频数列表即可得到对应token的频数值
注意:由于有时需要构造特殊字符,如<UNK>
,而且在构造时会直接放到最前面,因此我们在对应频数的列表时,需要把特殊字符构造的下标删除掉
前面的内容
Torchtext 0.12+新版API学习与使用示例(1)
Torchtext 0.12+ API构造训练用DataLoader与词向量的Embedding(2)
示例代码
from torchtext.vocab import vocab
from collections import Counter, OrderedDict
from torch.utils.data import Dataset, DataLoader
from torchtext.transforms import VocabTransform
import numpy as np
class TextDataSet(Dataset):
def __init__(self, text_list):
"""
使用新版API的一个简单的TextDataSet
:param text_list: 语料的全部句子
"""
# 这里使用频数排序构造 torchtext 的 vocab
total_word_list = []
for _ in text_list: # 将嵌套的列表([[xx,xx],[xx,xx]...])拉平 ([xx,xx,xx...])
total_word_list += _.split(" ")
counter = Counter(total_word_list) # 统计计数
sorted_by_freq_tuples = sorted(counter.items(), key=lambda x: x[1], reverse=True) # 构造成可接受的格式:[(单词,num), ...]
ordered_dict = OrderedDict(sorted_by_freq_tuples)
# 开始构造 vocab
specials = ["<UNK>", "<SEP>"]
self.specials = specials # 特殊字符需要记录下来!
my_vocab = vocab(ordered_dict, specials=specials) # 单词转token,specials里是特殊字符,可以为空
vocab_transform = VocabTransform(my_vocab)
# 开始构造DataSet
self.text_list = text_list # 原始文本
self.vocab_transform = vocab_transform
self._len = len(text_list) # 文本量
# =============前面的都是套路,真正有用的就下面一行,前面的部分有疑问请参考前面的博客内容==================
self.freq = [i[1] for i in sorted_by_freq_tuples] # 记录频数统计
def __getitem__(self, id_index): # 每次循环的时候返回的值
sentence = self.text_list[id_index]
word_ids = self.vocab_transform(sentence.split(' '))
freq = np.take(self.freq, np.array(word_ids) - len(self.specials))
return word_ids, freq
def __len__(self):
return self._len
def main():
sentence_list = [ # 假设这是全部的训练语料
"nlp is natural language processing strives",
"nlp build machines that understand",
"nlp model respond to text or voice data and respond with text",
]
text_dataset = TextDataSet(sentence_list) # 构造 DataSet
data_loader = DataLoader(text_dataset, batch_size=1) # 将DataSet封装成DataLoader
for sentence, word_freq in data_loader:
print("====================================")
print("原句是:", sentence)
print("每个单词的频数:", word_freq)
if __name__ == '__main__':
main()