LSTM文本分类(tensorflow)

2023-05-16

1)LSTM介绍
转载自https://www.csdn.net/article/2015-09-14/2825693
Gates:
这里写图片描述
输入变换:
这里写图片描述
状态更新:
这里写图片描述
使用图片描述类似下图:
这里写图片描述
输入
首先,让我们来定义输入形式。在lua中类似数组的对象称为表,这个网络将接受一个类似下面的这个张量表。
这里写图片描述

local inputs = {}
table.insert(inputs, nn.Identity()())  -- network input
table.insert(inputs, nn.Identity()())  -- c at time t-1
table.insert(inputs, nn.Identity()())  -- h at time t-1
local input = inputs[1]
local prev_c = inputs[2]
local prev_h = inputs[3]

计算gate值

locali2h=nn.Linear(input_size,4*rnn_size)(input)-- input to hiddenlocalh2h=nn.Linear(rnn_size,4*rnn_size)(prev_h)-- hidden to hiddenlocalpreactivations=nn.CAddTable()({i2h,h2h})-- i2h + h2h

这里写图片描述
运用非线性

-- gates
localpre_sigmoid_chunk=nn.Narrow(2,1,3*rnn_size)(preactivations)
localall_gates=nn.Sigmoid()(pre_sigmoid_chunk)
-- input
localin_chunk=nn.Narrow(2,3*rnn_size+1,rnn_size)(preactivations)
localin_transform=nn.Tanh()(in_chunk)

在非线性操作之后,我们需要增加更多的nn.Narrow,然后我们就完成了gates。

localin_gate=nn.Narrow(2,1,rnn_size)(all_gates)
localforget_gate=nn.Narrow(2,rnn_size+1,rnn_size)(all_gates)
localout_gate=nn.Narrow(2,2*rnn_size+1,rnn_size)(all_gates)

这里写图片描述
计算当前的Cell状态

-- previous cell state contribution
localc_forget=nn.CMulTable()({forget_gate,prev_c})
-- input contribution
localc_input=nn.CMulTable()({in_gate,in_transform})
-- next cell state
localnext_c=nn.CAddTable()({
 c_forget,
 c_input
})

实现hidden 状态计算

localc_transform=nn.Tanh()(next_c)
localnext_h=nn.CMulTable()({out_gate,c_transform})

这里写图片描述
实例:http://apaszke.github.io/assets/posts/lstm-explained/multilayer.lua
2)lstm实现文本分类
转载自 http://blog.csdn.net/u010223750/article/details/53334313?locationNum=7&fps=1
2.1原理
这里写图片描述
简单解释一下这个图,每个word经过embedding之后,进入LSTM层,这里LSTM是标准的LSTM,然后经过一个时间序列得到的t个隐藏LSTM神经单元的向量,这些向量经过mean pooling层之后,可以得到一个向量h,然后紧接着是一个简单的逻辑斯蒂回归层(或者一个softmax层)得到一个类别分布向量。
2.2tensorflow基础
a) variable_scope

1. 使用tf.Variable()的时候,tf.name_scope()和tf.variable_scope() 都会给 Variable 和 op 的 name属性加上前缀。 
2. 使用tf.get_variable()的时候,tf.name_scope()就不会给 tf.get_variable()创建出来的Variable加前缀。

b) tf.nn.embedding_lookup

tf.nn.embedding_lookup函数的用法主要是选取一个张量里面索引对应的元素。tf.nn.embedding_lookup(tensor, id):tensor就是输入张量,id就是张量对应的索引

c) tf.device

比如第一个GPU的名称为/gpu:0,第二个GPU名称为/gpu:1,以此类推。

d) variable转为python int
参见https://www.tensorflow.org/api_docs/python/tf/to_float

tf.to_int32(tf.variable)
tf.to_int64(tf.variable)
tf.to_float(tf.varialbe)

e)tf.nn.sparse_softmax_cross_entropy_with_logits
错误tensorflow:Only call sparse_softmax_cross_entropy_with_logits with named arguments
解决办法如下:
tf.nn.sparse_softmax_cross_entropy_with_logits(logits, train_labels_node))
改为tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=train_labels_node)
2.3代码
转载自 https://github.com/luchi007/RNN_Text_Classify


import tensorflow as tf
import numpy as np

class RNN_Model(object):

    def __init__(self,config,is_training=True):

        self.keep_prob=config.keep_prob
        self.batch_size=tf.Variable(0,dtype=tf.int32,trainable=False)

        num_step=config.num_step
        self.input_data=tf.placeholder(tf.int32,[None,num_step])
        self.target = tf.placeholder(tf.int64,[None])
        self.mask_x = tf.placeholder(tf.float32,[num_step,None])

        class_num=config.class_num
        hidden_neural_size=config.hidden_neural_size
        vocabulary_size=config.vocabulary_size
        embed_dim=config.embed_dim
        hidden_layer_num=config.hidden_layer_num
        self.new_batch_size = tf.placeholder(tf.int32,shape=[],name="new_batch_size")
        self._batch_size_update = tf.assign(self.batch_size,self.new_batch_size)

        #build LSTM network
        lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(hidden_neural_size,forget_bias=0.0,state_is_tuple=True)
        if self.keep_prob<1:
            lstm_cell =  tf.nn.rnn_cell.DropoutWrapper(
                lstm_cell,output_keep_prob=self.keep_prob
            )

        cell = tf.nn.rnn_cell.MultiRNNCell([lstm_cell]*hidden_layer_num,state_is_tuple=True)

        self._initial_state = cell.zero_state(tf.to_int32(self.batch_size),dtype=tf.float32)

        #embedding layer
        with tf.device("/gpu:0"),tf.name_scope("embedding_layer"):
            embedding = tf.get_variable("embedding",[vocabulary_size,embed_dim],dtype=tf.float32)
            inputs=tf.nn.embedding_lookup(embedding,self.input_data)

        if self.keep_prob<1:
            inputs = tf.nn.dropout(inputs,self.keep_prob)

        out_put=[]
        state=self._initial_state
        with tf.variable_scope("LSTM_layer"):
            for time_step in range(num_step):
                if time_step>0: tf.get_variable_scope().reuse_variables()
                (cell_output,state)=cell(inputs[:,time_step,:],state)
                out_put.append(cell_output)

        out_put=out_put*self.mask_x[:,:,None]

        with tf.name_scope("mean_pooling_layer"):

            out_put=tf.reduce_sum(out_put,0)/(tf.reduce_sum(self.mask_x,0)[:,None])

        with tf.name_scope("Softmax_layer_and_output"):
            softmax_w = tf.get_variable("softmax_w",[hidden_neural_size,class_num],dtype=tf.float32)
            softmax_b = tf.get_variable("softmax_b",[class_num],dtype=tf.float32)
            self.logits = tf.matmul(out_put,softmax_w)+softmax_b

        with tf.name_scope("loss"):
            self.loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits+1e-10,labels=self.target)
            self.cost = tf.reduce_mean(self.loss)

        with tf.name_scope("accuracy"):
            self.prediction = tf.argmax(self.logits,1)
            correct_prediction = tf.equal(self.prediction,self.target)
            self.correct_num=tf.reduce_sum(tf.cast(correct_prediction,tf.float32))
            self.accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32),name="accuracy")

        if not is_training:
            return

        self.globle_step = tf.Variable(0,name="globle_step",trainable=False)
        self.lr = tf.Variable(0.0,trainable=False)

        tvars = tf.trainable_variables()
        grads, _ = tf.clip_by_global_norm(tf.gradients(self.cost, tvars),
                                      config.max_grad_norm)
        optimizer = tf.train.GradientDescentOptimizer(self.lr)
        optimizer.apply_gradients(zip(grads, tvars))
        self.train_op=optimizer.apply_gradients(zip(grads, tvars))

        self.new_lr = tf.placeholder(tf.float32,shape=[],name="new_learning_rate")
        self._lr_update = tf.assign(self.lr,self.new_lr)

    def assign_new_lr(self,session,lr_value):
        session.run(self._lr_update,feed_dict={self.new_lr:lr_value})
    def assign_new_batch_size(self,session,batch_size_value):
        session.run(self._batch_size_update,feed_dict={self.new_batch_size:batch_size_value})

2.3结果
这里写图片描述
2.4此代码局限

a)二分类
b)基于标题的分类
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

LSTM文本分类(tensorflow) 的相关文章

随机推荐

  • SIP鉴权简介

    介绍 SIP提供了一个无状态 基于挑战的鉴权机制 xff0c 该机制基于HTTP的鉴权 任何时候一个UA或代理服务器收到一个请求 除CANCEL和ACK xff0c 都可以挑战请求的发起者要求其提供身份的保证 一旦发起者判定了身份 xff0
  • Java面试必背八股文[10]:RabbitMQ

    什么是 rabbitmq 采用 AMQP Advanced Message Queuing Protocol xff0c 高级消息队列协议 xff09 的一种消息队列技术 最大的特点就是消费并不需要确保提供方存在 xff0c 实现了服务之间
  • Java面试必背八股文[11]:计算机网络

    OSI与TCP IP各层的结构 xff1f 答 OSI分层 xff08 7层 xff09 xff1a 物理层 数据链路层 网络层 传输层 会话层 表示层 应用层 TCP IP分层 xff08 4层 xff09 xff1a 网络接口层 网际层
  • Java面试必背八股文[12]:计算机操作系统

    进程和线程有什么区别 xff1f 进程 xff08 Process xff09 是系统进行资源分配和调度的基本单位 xff0c 线程 xff08 Thread xff09 是CPU调度和分派的基本单位 xff1b 线程依赖于进程而存在 xf
  • 2022年总结

    2022年总结 人生的转折痛并快乐着愿岁月静好未来加油吧 人生的转折 2022年 人生的转折点了 xff0c 研究生毕业 xff0c 再也没有了那个 学生 的身份 xff0c 新的篇章 xff0c 如何续写 xff1f 2022年6月20日
  • 【进阶】"结构体嵌入共联体"在协议解析中的神操作!

    1 聊一聊 34 I was alone but not lonely 34 今天的文章话题引出来自bug技术交流群 xff0c 主要是想把这种协议解析和设计的方式分享给大家 xff01 2 正文部分 1 话题引出 bug技术交流群一个小哥
  • Linux中模拟GET、POST请求

    1 概述 在Linux系统中 xff0c 可以利用命令来模拟HTTP请求中的GET POST PUT等请求 xff0c 本文将阐述基于curl命令来模拟GET与POST请求 xff0c PUT DELETE等请求与POST类似 xff0c
  • yolo自带标注工具yolo_mark下载及使用说明

    官网写的比较详细 xff0c 下载参考 https github com AlexeyAB Yolo mark 双击运行windows命令脚本 xff0c 而不是exe 将要标注的样本路径 xff0c 写入train txt文件中 上面这个
  • C++语法(二十)常函数、常对象

    1 常函数 常函数无法修改成员变量 xff0c 除非这个成员变量用mutable修饰了 include lt iostream gt using namespace std class Person public void change c
  • rplidar_ros 报错:can‘t bind 和Operation Time Out的解决

    我使用的思岚A2的雷达在ros下运行 1 can t bind无法连接的错误 xff0c 一种是设备号不匹配引起的错误 xff0c 首先可以使用ll dev grep ttyUSB查看一下设备的dev号 xff0c 再检查一下rplidar
  • 串口通信与网络通信

    上一篇文章记录了使用C Winform开发串口通讯的上位机软件 xff0c 而笔者在整个职业经历中开发得较多的还是网络通讯软件 xff0c 通过以太网TCP IP UDP协议实现不同服务器应用程序之间数据传送与接收 xff1b 而随着公司业
  • Unity URP自学笔记四 ShaderGraph

    在ShaderGraph中自定义光照计算 xff0c 主要需要获取光照的颜色和方向 xff0c 这些需要自己通过脚本来获取 例如通过CustomFunction结点来处理 xff1a 下面创建了一个半兰伯特SubGraph xff0c 便于
  • [笔记]STM32基于HAL库实现STM32串口中断接收数据

    这里使用USART1串口 usart c中添加 xff08 1 xff09 添加全局变量 uint8 t USART1 Buff 100 61 0 接收帧缓存 xff0c 自己定义大小 uint8 t USART1 STA 61 0 boo
  • matlab--UDP发送接收

    m函数中UDP接收和发送 接收 ipA 192 168 0 5 portA 8080 ipB 192 168 0 3 portB 8080 handles udpB udp ipA portA LocalPort portB 远程ip 远程
  • mysql--日志

    转载自 xff1a https www cnblogs com f ck need u p 9001061 html 日志刷新 mysql gt FLUSH LOGS 错误日志 简介 错误日志记录了MySQL Server每次启动和关闭的详
  • osg--读写

    文件I O 命名规则 osgdb xxx 比如 osgdb osg osgdb jpeg 关联文件后缀和加载器 osgDB Registry instance gt addFileExtensionAlias jpeg jpeg osgDB
  • osg--几种效果

    billboards 适用于小草等的绘制 osg BillBoard继承自osg Geode 其下所有osg Drawable面向观察者 旋转行为通过setMode 设置 分别为 POINT ROT EYE 几何体z轴旋转到窗口y轴 POI
  • osg--提高效率

    多线程 OpenThreads Thread 虚函数 cancel run OpenThreads Mutex OpenThreads Barrier OpenThreads Condition 线程管理 GetNumberOfProces
  • torch在ubuntu16.04下的搭建(cuda9.0+cudnn7.0)

    希望外婆身体越来越好 参考 xff1a http blog csdn net chenhaifeng2016 article details 68957732 http www 52nlp cn E6 B7 B1 E5 BA A6 E5 A
  • LSTM文本分类(tensorflow)

    1 xff09 LSTM介绍 转载自https www csdn net article 2015 09 14 2825693 Gates xff1a 输入变换 xff1a 状态更新 xff1a 使用图片描述类似下图 xff1a 输入 首先