用tensorflow实现基本的word2vec

2023-11-08

"""Basic word2vec implementation through tensorflow"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from glob import glob
import collections
import math
import os
import sys
import time
import argparse
import random

import numpy as np
from six.moves import urllib
from six.moves import xrange
import tensorflow as tf


current_path = os.path.dirname(os.path.realpath(sys.argv[0]))

parser = argparse.ArgumentParser()
parser.add_argument(
    '--log_dir',
    type=str,
    default=os.path.join(current_path, 'output'),
    help='The log directory for TensorBoard summaries.')
FLAGS, unparsed = parser.parse_known_args()

# Create the directory for TensorBoard variables if there is not.
if not os.path.exists(FLAGS.log_dir):
  os.makedirs(FLAGS.log_dir)


files_list=glob("*************")
def read_data(path):
  vocabulary=[]
  with open(path, 'r') as f:
    for line in f:
      vocabulary += line.split()
  return vocabulary

# Step1: get the dictionary and reverse_dictionary
vocab_uni=[]
dictionary={}
dictionary['UNK'] = 0
for ii in files_list:
  # print("file is:",ii)
  vocabulary=read_data(ii)
  for voca in vocabulary:
    if voca in dictionary.keys():
      pass
    else:
      dictionary[voca]=len(dictionary)

  vocab_uni+=vocabulary
  del vocabulary
  vocab_uni=list(set(vocab_uni))
reverse_dictionary = dict(zip(dictionary.values(), dictionary.keys()))

print("size of vocab is:",len(vocab_uni))

# Step 2: Build the dictionary and replace rare words with UNK token.
vocabulary_size = len(vocab_uni)+1 #50000

# 建立数据集,words是所有单词的列表,n_words是想建的字典中单词的个数
def build_dataset(words, n_words, dictionary):
  """Process raw inputs into a dataset."""
  #将所有低频单词设为UNK,个数先设为-1
  count = [['UNK', -1]]
  data = list()
  unk_count = 0
  for word in words:
    index = dictionary.get(word, 0)
    if index == 0:  # dictionary['UNK']
      unk_count += 1
    data.append(index)
  #记录UNK个数
  count[0][1] = unk_count

  return data, count



# Step 3: Function to generate a training batch for the skip-gram model.
def generate_batch(data,batch_size, num_skips, skip_window):

  global data_index
  global has_next
  assert batch_size % num_skips == 0
  assert num_skips <= 2 * skip_window
  batch = np.ndarray(shape=(batch_size), dtype=np.int32)
  labels = np.ndarray(shape=(batch_size, 1), dtype=np.int32)
  span = skip_window + 1
  buffer = collections.deque(maxlen=span)
  # #建立一个结构为双向队列的缓冲区,大小不超过3
  if data_index + span > len(data):
    has_next=False
  # 如果索引超过了数据长度,则重新从数据头部开始
  buffer.extend(data[data_index:data_index + span])
  data_index += span    #将index向后移3位
  for i in range(batch_size // num_skips):
    context_words = [w for w in range(span) if w != 0]
    words_to_use = random.sample(context_words, num_skips)  
    # start_words=len(words_to_uselen)//2
    # words_to_use=context_words[start_words:]

    for j, context_word in enumerate(words_to_use):
      batch[i * num_skips + j] = buffer[0]  #在batch中存入当前单词
      labels[i * num_skips + j, 0] = buffer[context_word]
    if data_index == len(data):
      buffer.extend(data[0:span])
      data_index = span
      has_next=False
    elif data_index>len(data):
      has_next=False
    else:
      buffer.append(data[data_index])
      data_index += 1  #当前单词的索引向后移一位
  # Backtrack a little bit to avoid skipping words in the end of a batch
  data_index = (data_index + len(data) - span) % len(data)
  # 避免循环结束后刚好停在data尾部,以防下次运行该函数向后移动三位index时越界
  # print("batch is",batch)
  # print("labels is",labels)
  return batch, labels


batch_size = 1024
embedding_size = 128  # Dimension of the embedding vector.
skip_window = 5  # How many words to consider left and right.
num_skips = 5  # How many times to reuse an input to generate a label.
num_sampled = 10  # Number of negative examples to sample.

valid_size = 16  # Random set of words to evaluate similarity on.
valid_window = 20  # Only pick dev samples in the head of the distribution. #原为100
valid_examples = np.random.choice(valid_window, valid_size, replace=False)
print("valid examples",valid_examples)

graph = tf.Graph()

with graph.as_default():

  # Input data.
  with tf.name_scope('inputs'):
    train_inputs = tf.placeholder(tf.int32, shape=[batch_size])
    train_labels = tf.placeholder(tf.int32, shape=[batch_size, 1])
    valid_dataset = tf.constant(valid_examples, dtype=tf.int32)

  # Ops and variables pinned to the CPU because of missing GPU implementation
  with tf.device('/gpu:0'):
    # Look up embeddings for inputs.
    with tf.name_scope('embeddings'):
      embeddings = tf.Variable(
          tf.random_uniform([vocabulary_size, embedding_size], -1.0, 1.0))
      embed = tf.nn.embedding_lookup(embeddings, train_inputs)

    # Construct the variables for the NCE loss
    with tf.name_scope('weights'):
      #initialization train parameters
      nce_weights = tf.Variable(
          tf.truncated_normal(
              [vocabulary_size, embedding_size],
              stddev=1.0 / math.sqrt(embedding_size)))
    #initialization bias
    with tf.name_scope('biases'):
      nce_biases = tf.Variable(tf.zeros([vocabulary_size]))

  # Compute the average NCE loss for the batch.

  with tf.name_scope('loss'):
    loss = tf.reduce_mean(
        tf.nn.nce_loss(
            weights=nce_weights,
            biases=nce_biases,
            labels=train_labels,
            inputs=embed,
            num_sampled=num_sampled,
            num_classes=vocabulary_size))

  # Add the loss value as a scalar to summary.
  tf.summary.scalar('loss', loss)

  # Construct the SGD optimizer using a learning rate of 1.0.
  with tf.name_scope('optimizer'):
    optimizer = tf.train.GradientDescentOptimizer(1.0).minimize(loss)

  #embedding normalization
  norm = tf.sqrt(tf.reduce_sum(tf.square(embeddings), 1, keepdims=True))
  normalized_embeddings = embeddings / norm
  #找到验证集中的id对应的embedding
  valid_embeddings = tf.nn.embedding_lookup(normalized_embeddings,
                                            valid_dataset)
  #判断验证集和整个归一化的embedding的相似性
  similarity = tf.matmul(
      valid_embeddings, normalized_embeddings, transpose_b=True)

  # Merge all summaries.
  merged = tf.summary.merge_all()

  # Add variable initializer.
  init = tf.global_variables_initializer()

  # Create a saver.
  saver = tf.train.Saver()

# Step 5: Begin training.
num_steps = 10000
with tf.Session(graph=graph) as session:
  # Open a writer to write summaries.
  writer = tf.summary.FileWriter(FLAGS.log_dir, session.graph)

  # We must initialize all variables before we use them.
  init.run()
  print('Initialized')
  average_loss = 0
  step=0
  for path in files_list:
    # num_steps+=10001
    # start=num_steps-10001
    print("path is:",path)
    vocabulary=read_data(path)
    data_index=0
    has_next=True
    data, count = build_dataset(
        vocabulary, vocabulary_size, dictionary)
    # time.sleep(1)
    del vocabulary  # Hint to reduce memory.
    while step<3000:
      step+=1
      # for step in xrange(num_steps):
      #生成一个batch的训练数据
      batch_inputs, batch_labels = generate_batch(data,batch_size, num_skips,
                                                  skip_window)

      feed_dict = {train_inputs: batch_inputs, train_labels: batch_labels}

      # Define metadata variable.
      run_metadata = tf.RunMetadata()

      _, summary, loss_val = session.run(
          [optimizer, merged, loss],
          feed_dict=feed_dict,
          run_metadata=run_metadata)
      average_loss += loss_val

      # Add returned summaries to writer in each step.
      writer.add_summary(summary, step)
      # Add metadata to visualize the graph for the last run.
      if step == (num_steps - 1):
        writer.add_run_metadata(run_metadata, 'step%d' % step)
      # compute average loss eval 20000 steps
      if step % 10000 == 0:
        if step > 0:
          average_loss /= step
        # The average loss is an estimate of the loss over the last 2000 batches.
        print('Average loss at step ', step, ': ', average_loss)
        average_loss = 0

      # Note that this is expensive (~20% slowdown if computed every 500 steps)
      if step % 50000 == 0:
        # 每10000步评估一下验证集和整个embeddings的相似性
        # 结果是验证集中每个词和字典中所有词的相似性
        sim = similarity.eval()
        for i in xrange(valid_size):
          valid_word = reverse_dictionary[valid_examples[i]]
          #因为两个向量相乘,值越小越相似(余弦定理),这里找出前8个最相似的词
          top_k = 8
          nearest = (-sim[i, :]).argsort()[1:top_k + 1]
          log_str = 'Nearest to %s:' % valid_word
          for k in xrange(top_k):
            # 根据id找到对应的word
            # print("reverse dictionary is:",reverse_dictionary)
            # print("near is",nearest)
            close_word = reverse_dictionary[nearest[k]]
            log_str = '%s %s,' % (log_str, close_word)
          print(log_str)
    final_embeddings = normalized_embeddings.eval()


# Step 6: 输出词向量
with open('word2vec_karate.txt', "w", encoding="UTF-8") as fW2V:
    fW2V.write(str(vocabulary_size) + ' ' + str(embedding_size) + '\n')
    for i in xrange(final_embeddings.shape[0]):
        sWord = reverse_dictionary[i]
        sVector = ''
        for j in xrange(final_embeddings.shape[1]):
            sVector = sVector + ' ' + str(final_embeddings[i, j])
        fW2V.write(sWord + sVector + '\n')
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

用tensorflow实现基本的word2vec 的相关文章

  • Cookie 和 Session 归纳

    首先介绍下基本概念 cookie是服务器通知客户端让其保存健值对的一种形式 客户端有了cookie之后 每次请求就会发送给服务器 每个cookie最大是4kb 服务器创建cookie 1Cookie cookie new Cookie 创建
  • 关于typedef的用法总结

    typedef的应用 typedef是C 语言中用于为现有数据类型指定替代名称的关键字 它主要用于用户定义的数据类型 当数据类型的名称在程序中使用变得稍微复杂时 以下是使用的一般语法 typedef
  • 删除卷与分页文件(虚拟内存文件)

    无法删除卷可能是由于这个磁盘中存在分页文件 虚拟内存文件 引起的
  • 软件程序如何运行的-简述

    开门见山 咱不说废话 你有没有想过 你写的程序 是如何在计算机中运行的吗 比如我们搞Java的 肯定写过这段代码 public class HelloWorld public static void main String args Sys
  • shell I/O重定向

    shell重定向 lt 改变标准输入 program lt file 可将program 的标准输入改为file tr d r lt dos file txt 以 gt 改变标准输出 program gt file 可将program的标准
  • win10如何校验文件哈希值

    转自 https jingyan baidu com article 67662997a9b06654d51b84a1 html 文件的哈希值可以用软件计算 算法一样 无须多讲 本文讲述如何用win10自带命令计算 右击开始 点击windo
  • 计算机网络-详细版

    鉴于有人需要离线版的PDF文档 这里给出本文章的PDF版本 下载地址如下 https pan itnxd cn 123Pan csdn share computer network pdf 一 计算机网络体系结构 0 脑图 1 计算机网络概
  • 记一次Tomcat日志分析:一个或多个listeners启动失败,更多详细信息查看对应的容器日志文件

    1 问题 我将一个应用 MicroStrategy 11 3 0000 13515 部署到Tomcat 然后 我点击start后报错 FAIL Application at context path MicroStrategy 11 3 0
  • MySQL8 EXPLAIN 命令输出的都是什么东西?这篇超详细!

    引子 小扎刚毕业不久 在一家互联网公司工作 由于是新人 做的也都是简单的CRUD 刚来的时候还有点不适应 做了几个月之后 就变成了熟练工了 左复制 右粘贴 然后改改就是自己的代码了 生活真美好 有一天 领导说他做的有个列表页面速度很慢 半天
  • LRU 最近最少使用算法

    LRU 最近最少使用算法 设计LRU Cash 数据结构 设计方法 代码实现 总结 百度百科 LRU是Least Recently Used的缩写 即最近最少使用 是一种常用的页面置换算法 选择最近最久未使用的页面予以淘汰 该算法赋予每个页
  • 装系统时提示 无法在驱动器0分区上安装windows

    先看提示 先看提示 先看提示 1 在重装系统时遇到一个问题 无法在驱动器0分区上安装windows 2 解决方法 1 在当前安装界面按住Shift F10调出命令提示符窗口 2 输入diskpart 按回车执行 3 进入DISKPART命令
  • 负数为什么要用补码来表示?

    上篇文章讲了 负数在计算机中是怎么存储的 看完之后 应该对原码 反码 补码有了基本的了解了 今天 我们深入探讨一下 为什么计算机中要用补码来表示负数 首先 我们应该清楚 原码是方便给人看的 看到一个数的原码 我们就能根据符号位和后边的二进制
  • Intellij多行同时缩进或者同时空格

    在使用JetBrains旗下的集成软件 如IDEA Pycharm PhpStorm Clion等时 通常需要整体向前或者向后缩进代码 以更加美观地编写代码 此时 可通过以下两个快捷键实现该功能 1 代码整体向后缩进 选中多行代码 按下ta
  • spring-boot后端解决跨域问题

    代码 import cn hutool log Log import cn hutool log LogFactory import com alibaba fastjson JSONObject import org springfram
  • ISO/OSI七层模型

    想要让两台PC进行通信 必须使用相同的信息交换规则 我们把计算机网络中用于规定信息的格式 以及如何发送和接受信息的一套规则称谓网络协议或者通信协议 我们为了减少网络设计的复杂 人们按功能将计算机网络划分为多个不同功能的层 网络体系结构就是网
  • 计算机组成原理笔记

    文章目录 一 计算机的基本组成 二 总线 2 1 总线控制 三 主存储器 3 1 RAM 3 2 存储器与CPU相连 3 3 存储器校验 3 4 提高存储器访问速度 3 5 cache 四 输入输出系统 4 1 I O接口 4 2 程序中断
  • GnuTLS recv error (-110): The TLS connection was non-properly terminated问题的解决方案

    我在使用git clone branch 3 4 depth 1 https github com opencv opencv git命令的时候 遇到如下问题 fatal unable to access https github com
  • Linux操作系统与Shell编程

    Linux是自由 开源的操作系统 安装在计算机的硬件之上 是用来操作计算机硬件和软件资源的系统软件 一般应用于专业的web服务器上 具有以下特性 Linux注重系统的安全性 对文件访问权限有严格设定 最高权限账户为root用户 可以操作一切
  • 字符编码和字符集有什么区别?Unicode是什么,和UTF-8是什么关系?你想知道的都在这篇文章了

    前言 想必大家编写代码时肯定和我一样 也遇到过汉字乱码的问题 特别是 有时候和上下游对接接口 不能统一编码格式的话 一堆乱码问题 让人头皮发麻 那么为什么会有这么多的乱码问题 什么是字符编码 什么是字符集 他们之间有什么区别和联系 什么是
  • 【编译原理】 CS143 斯坦福大学公开课 第一周:简介

    youtube 1 1 Introduction to Compilers and interpreters 1 1 Introduction to Compilers and interpreters 编译器解释器介绍 两种主要的实现编程

随机推荐

  • 冻结表格列PyQt

    QT有个官方的例子 Frozen Column Example 在Qt Creator例子查找即可 官方例子python版本 Frozen Column Example Qt for Python 不过官方python版应该是机器直接翻译的
  • C++ continue 语句

    C 中的 continue 语句有点像 break 语句 但它不是强迫终止 continue 会跳过当前循环中的代码 强迫开始下一次循环 对于 for 循环 continue 语句会导致执行条件测试和循环增量部分 对于 while 和 do
  • 十进制转换为二进制代码

    十进制转换为二进制代码 十进制转换为二进制 十进制如何转二进制 将该数字不断除以2直到商为零 然后将余数由下至上依次写出 即可得到该数字的二进制表示 以将数字21转化为二进制为例 当商为零时 将余数由下至上依次写出 即为21的二进制表示 i
  • SpringBoot整合框架——数据库

    目录 一 整合JDBC使用 1 1 SpringData简介 1 2 创建测试项目测试数据源 1 3 JDBCTemplate JdbcTemplate主要提供以下几类方法 1 4 测试 二 整合Druid
  • java 有限状态机_有限状态机的4种Java实现对比

    写在前面 2020年面试必备的Java后端进阶面试题总结了一份复习指南在Github上 内容详细 图文并茂 有需要学习的朋友可以Star一下 GitHub地址 https github com abel max Java Study Not
  • 【Redis】Redis安装与配置:

    文章目录 一 下载与安装 二 服务启动与停止 1 启动 2 设置后台运行 3 设置密码 解开注释 将默认密码foobared修改为你的 4 设置远程连接 一 下载与安装 redis https redis io download tar z
  • python详细安装教程-Pycharm及python安装详细教程(图解)

    首先我们来安装python 1 首先进入网站下载 点击打开链接 或自己输入网址https www python org downloads 进入之后如下图 选择图中红色圈中区域进行下载 2 下载完成后如下图所示 3 双击exe文件进行安装
  • sqli-labs Less-4

    本系列文章使用的靶场环境为sqli labs 环境下载地址 https github com Audi 1 sqli labs 持续跟新 一直到通过此靶场为止 1 判断注入类型 index php id 1 单引号回显正常 双引号会报错 然
  • Pyhon加载模块的两种方法

    一 在Python中添加 1 找到Settings 2 找到Project Interpreter 3 点击加号 4 在搜索栏搜索想要的模块 二 利用cmd安装 1 打开cmd 2 输入python 查看能否显示版本信息 不能的话需要配置环
  • APP 测试过程中缺陷总结

    1 拍照视频 问题1 视频拍照 文案和图标不一致 操作1 拍摄照片 点击拍摄视频 查看照片大图 确认 操作2 看系统是否存在两模式 定制和非定制 且都拥有这个视频和拍照功能 定制模式下切换到视频时 退出登录 登录到非定制版 2 上传 问题1
  • Log4j2 日志脱敏

    日志脱敏首先要搞清楚 影响的数据范围 是要全局支持日志脱敏 还是只针对部分代码 如果涉及到敏感数据的业务代码较少 建议写个数据脱敏工具类 在打印日志的时候调用 灵活可靠 影响范围小 一 第一种方案 全局方式 针对log4j2的日志脱敏实现方
  • MySql数据库基础--数据类型优化

    文章目录 数据类型的优化 各类型的特点 整数 实数 字符串 TEXT BLOB 日期时间 选择标识符 约束 数据类型的优化 优化原则 从小 更小通常更快 占用更小的磁盘空间 内存 cpu缓存 更少的cpu周期 从简 更少的cpu周期 整形比
  • vue 父子孙页面传值的多种方法

    父给子 第一种 props 缺点 只能一级一级的传值 子页面不能修改这个参数 父页面
  • 【单片机】UART、I2C、SPI、TTL、RS232、RS422、RS485、CAN、USB、SD卡、1-WIRE、Ethernet等常见通信方式

    在单片机开发中 UART I2C RS485等普遍在用 这里做一个简单的介绍 UART通用异步收发器 UART口指的是一种物理接口形式 硬件 UART是异步 指不使用时钟同步 依靠帧长进行判断 全双工 收发可以同时进行 串口总线 它比同步串
  • android 常用机型尺寸_Android中图片大小与各种hdpi

    前言 大家都知道开发android会涉及到UI的涉及 一般都是给到通用的分辨率进行设计 但是具体适配是需要代码控制的 由于网上分辨率dp的文章实在太多 对这些不了解的朋友可以去自行百度 这里主要是对UI的设计过程与原理进行一个简要的分析 术
  • 【数据结构】链表

    数据结构 链表 1 链表的概念及结构 链表是一种物理存储单元上非连续 非顺序的存储结构 数据元素的逻辑顺序是通过链表中的指针链接次序实现的 链表由一系列结点 链表中每一个元素称为结点 组成 结点可以在运行时动态生成 每个结点包括两个部分 一
  • Qt 安装包制作(基于Qt Installer Framework)

    目录 下载 Qt Installer Framework 程序打包发布 创建安装包程序 下载 Qt Installer Framework 官方下载 http download qt io official releases qt inst
  • 此URL不支持Http方法GET

    出现删除问题的解决办法是 需要把代码中的super注释掉 super doGet req resp super doPost req resp 原来报错的代码 修改之后的代码 记得修改之后 重启一下tomcat就可以了
  • java连接多个mysql_Java连接到多个数据库

    我正在创建一个连接到多个数据库的Java应用程序 用户将能够从下拉框中选择要连接的数据库 然后 程序通过将名称传递给创建初始上下文的方法来连接到数据库 以便它可以与oracle Web逻辑数据源进行通信 public class dbMai
  • 用tensorflow实现基本的word2vec

    Basic word2vec implementation through tensorflow from future import absolute import from future import division from fut