If you measure similarity by the occurrence of common words, you don't even need spaCy: just vectorize the texts with word counts and feed them into any clustering algorithm. AgglomerativeClustering is one option: it is not very time-efficient for large datasets, but it is highly controllable. The only parameter you need to tune for your dataset is distance_threshold: the smaller it is, the more clusters you get.
After clustering the texts, you can concatenate all the unique words within each cluster (or do something smarter, depending on the final problem you are solving). The whole code could look like this:
texts = '''yellow color
yellow color looks like
yellow color bright
red color okay
red color blood'''.split('\n')
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.preprocessing import Normalizer, FunctionTransformer
from sklearn.cluster import AgglomerativeClustering
from sklearn.pipeline import make_pipeline
model = make_pipeline(
    CountVectorizer(),
    Normalizer(),
    # densify: AgglomerativeClustering does not accept sparse input,
    # and .toarray() (unlike .todense()) returns a plain ndarray
    FunctionTransformer(lambda x: x.toarray(), accept_sparse=True),
    AgglomerativeClustering(distance_threshold=1.0, n_clusters=None),
)
clusters = model.fit_predict(texts)
print(clusters) # [0 0 0 1 1]
from collections import defaultdict
cluster2words = defaultdict(list)
for text, cluster in zip(texts, clusters):
    for word in text.split():
        if word not in cluster2words[cluster]:
            cluster2words[cluster].append(word)
result = [' '.join(wordlist) for wordlist in cluster2words.values()]
print(result) # ['yellow color looks like bright', 'red color okay blood']
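To see the effect of distance_threshold described above (the smaller it is, the more clusters), you can sweep it over the same normalized count vectors. A minimal self-contained sketch on the toy texts from the example:

```python
from sklearn.cluster import AgglomerativeClustering
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.preprocessing import Normalizer

texts = '''yellow color
yellow color looks like
yellow color bright
red color okay
red color blood'''.split('\n')

# L2-normalized word-count vectors, densified for AgglomerativeClustering
X = Normalizer().fit_transform(CountVectorizer().fit_transform(texts)).toarray()

for threshold in [0.5, 1.0, 2.0]:
    labels = AgglomerativeClustering(
        distance_threshold=threshold, n_clusters=None
    ).fit_predict(X)
    print(f'distance_threshold={threshold}: {len(set(labels))} clusters')
```

With these five texts, 0.5 leaves every text in its own cluster, 1.0 gives the two clusters shown above, and 2.0 merges everything into one.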
You need spaCy (or any other framework with pretrained models) only if common words are not enough and you want to capture semantic similarity. In that case, the whole pipeline changes only slightly:
# !python -m spacy download en_core_web_lg
import spacy
import numpy as np
nlp = spacy.load("en_core_web_lg")
model = make_pipeline(
    FunctionTransformer(lambda x: np.stack([nlp(t).vector for t in x])),
    Normalizer(),
    AgglomerativeClustering(distance_threshold=0.5, n_clusters=None),
)
clusters = model.fit_predict(texts)
print(clusters) # [2 0 2 0 1]
You can see that the clustering here is obviously incorrect, so spaCy word vectors do not seem well suited to this particular problem.
If you want a pretrained model that captures semantic similarity between texts, I would recommend LASER instead. It is explicitly based on sentence embeddings and is highly multilingual:
# !pip install laserembeddings
# !python -m laserembeddings download-models
from laserembeddings import Laser
laser = Laser()
model = make_pipeline(
    FunctionTransformer(lambda x: laser.embed_sentences(x, lang='en')),
    Normalizer(),
    AgglomerativeClustering(distance_threshold=0.8, n_clusters=None),
)
clusters = model.fit_predict(texts)
print(clusters) # [1 1 1 0 0]