HAN论文模型代码复现与重构

2023-05-16

论文简介

本文主要介绍CMU在2016年发表在ACL的一篇论文:Hierarchical Attention Networks for Document Classification及其代码复现。

该论文是用于文档级情感分类(document-level sentiment classification)的,其模型架构如下:
在这里插入图片描述
该模型称为层次注意力模型(Hierarchical Attention Network),根据作者所述,

  • 层次是指:句子由单词组成,文档由句子组成,据此可以构建一个自下而上的层次结构。

  • 注意力是指:组成某个句子的单词对该句子的情感倾向的贡献是不同的,通常来说,形容词的贡献(如good)就比名词(如book)更大;同理,组成文档的句子对该文档的情感倾向的贡献也不同,例如某些句子可能仅仅是陈述事实,而另一些句子则很明显地表达出了自己的观点。据此,作者提出使用注意力机制来挖掘句子和文档中对情感分类比较重要的部分(btw,注意力机制比较成熟的最早的应用是Google发表的Attention is All you Need一文中)。

对词嵌入进行编解码的方式无非是双向GRU或CNN等,此处不再赘述。需要注意的是,该模型中的注意力机制分为两个部分,分别是word attention和sentence attention,即分别在单词和句子上应用注意力机制,可视化结果如下:
在这里插入图片描述
可以看出注意力机制的可视化结果高亮出了情感极性比较强的单词。例如左边带delicious的评论预测结果为4分(较好),带terrible的评论的预测结果为0分(极差)。
由此也可说明注意力机制是有效的。

代码复现及重构

(显然这篇几年前的论文的代码不是我写的)
本文参考了github上对该模型的复现代码:textClassifier,源代码就不详细解释了,稍有复杂的也就是数据处理部分,源码实现将训练data设为三维的,并在词嵌入后喂给了HAN模型。

考虑到源代码结构不是很清晰,也无法自定义输入的词嵌入的维度和训练数据集,因此本文对该代码进行了重构。

首先说明Python版本和依赖的库:

Python >= 3.6
numpy
pandas
re
bs4
pickle
sklearn
gensim
nltk
keras
tensorflow

Python版本需要大于3.6,至于其他库的话,只要版本不太落后一般都没问题

下面详细介绍改动的部分。

参数选项

原文没有提供参数选项,如果要输入不同维度的词嵌入文件,则每次都要修改源代码,十分不便,为此, 我在重构时加入了参数选项,主要代码如下:

parser = argparse.ArgumentParser('HAN')
parser.add_argument('--full_data_path', '-d', 
				help='Full path of  data', default=FULL_DATA_PATH)
parser.add_argument('--processed_pickle_data_path', '-D', 
				help='Full path of processed pickle data', default=PROCESSED_PICKLE_DATA_PATH)
parser.add_argument('--embedding_path', '-s', 
				help='The pre-trained embedding vector', default=EMBEDDING_PATH)
parser.add_argument('--model_path', '-m', help='Full path of  model', default=MODEL_PATH)
parser.add_argument('--epoch', '-e', help='Epochs', type=int, default=EPOCH)
parser.add_argument('--batch_size', '-b', help='Batch size', type=int, default=BATCH)
parser.add_argument('--training_data_ready', '-t', 
				help='Pass when training data is ready', action='store_true')
parser.add_argument('--model_ready', '-M', 
				help='Pass when model is ready', action='store_true')
parser.add_argument('--verbosity', '-v', 
				help='verbosity, stackable. 0: Error, 1: Warning, 2: Info, 3: Debug', action='count')
parser.description = 'Implementation of HAN for Sentiment Classification task'
parser.epilog = "Larry King@https://github.com/Larry955"

相应的变量定义在han_config.py文件中。

详细参数说明如下:

  • –full_data_path, 要输入的训练文件的路径,该文件必须为tsv格式
  • –processed_pickle_data_path, 已经处理过的数据集的路径
  • –embedding_path, 预训练词向量文件的路径
  • –model_path, 保存的模型的路径
  • –epoch, epoch个数
  • –batch_size, batch size
  • –training_data_ready, 数据集是否已被处理过,显式输入该参数时表明数据集已被处理过,否则会报错
  • –model_ready, 模型是否已保存好,显式输入该参数时表明模型已被保存,否则会报错
  • –verbosity, emmmm…

假设该文件为HAN. py,那么输入

python HAN.py --help

可得:
在这里插入图片描述
输入

python HAN.py --full_data_path=train_data.tsv --embedding_path=GoogleNews-vectors-negative300.bin --epoch=20

表示数据集的路径为train_data.tsv,预训练词嵌入文件为GoogleNews,epoch为20。
输入

python HAN.py --training_data_ready --model_ready

表示训练集和模型都已经准备好,可以直接加载。

词嵌入文件解析

原代码中只能解析glove词嵌入,并且词嵌入维度固定300维,我在重构时对词嵌入文件进行了简单的解析,使得模型可以接受不同的词嵌入文件(目前支持glove和GoogleNews两种),并能根据文件名提取出词嵌入的维度。主要代码如下:

emb_file_flag = ''
embedding_dim = 0

if embedding_path.find('glove') != -1:    
    emb_file_flag = 'glove'     # pre-trained word vector is glove    
    embedding_dim = int(((embedding_path.split('/')[-1]).split('.')[2])[:-1])
elif embedding_path.find('GoogleNews-vectors-negative300.bin') != 
-1:    
    emb_file_flag = 'google'    # pre-trained word vector is GoogleNews    
    embedding_dim = 300

得到词嵌入文件和维度后,再根据emb_file_flag针对不同的文件获取词向量:

embeddings_index = {}
if emb_file_flag == 'glove':    
    f = open(os.path.join(embedding_path), encoding='utf-8')    
    for line in f:        
        values = line.split()        
        word = values[0]        
        vec = np.asarray(values[1:], dtype='float32')        
        embeddings_index[word] = vec    
    f.close()
elif emb_file_flag == 'google':    
    wv_from_bin = KeyedVectors.load_word2vec_format(emb_path, 
binary=True)    
    for word, vector in zip(wv_from_bin.vocab, wv_from_bin.vectors):        
        vec = np.asarray(vector, dtype='float32')        
        embeddings_index[word] = vec

示例:

python HAN.py  --embedding_path=GoogleNews-vectors-negative300.bin  # pre-trained word vector file is GooleNews with 300d
python HAN.py --embedding_path=glove.6B.100d.txt    # pre-trained file is glove with 100d
python HAN.py --embedding_path=glove.6B.200d.txt    # 200d

保存已训练数据集及模型

原代码中,每次运行时都要对数据集进行处理,并且要重新训练模型,这对于百万级文档数据集而言十分耗时,为此,我在重构时设置了相应的参数选项,从而能通过直接加载保存的文件已避免多次训练,大大降低训练时间。代码如下:

  • 保存和加载已训练数据集
if is_training_data_ready:    
    with open(pickle_path, 'rb') as f:        
        # print('data ready')        
        data, labels, word_index = pickle.load(f)    
    f.close()
else:    
    data, labels, word_index = process_data(data_path)    
    with open(pickle_path, 'wb') as f:        
        pickle.dump((data, labels, word_index), f)    # save trained dataset
    f.close()
    
# Generate data for training, validation and test
x_train, x_test, y_train, y_test = train_test_split(data, labels, test_size=0.1, random_state=1)
x_train, x_val, y_train, y_val = train_test_split(x_train, y_train, test_size=0.1, random_state=1)
  • 保存和加载已训练模型
if is_model_ready:    
    # print('model ready')    
    model = load_model(model_path, custom_objects={'AttLayer': 
AttLayer})
else:    
# Generate embedding matrix consists of embedding vector    
    embedding_matrix = create_emb_mat(embedding_path, word_index, 
embedding_dim)    # Create model for training    
    model = create_model(embedding_matrix)    
    model.save(model_path)  # save model

需要说明的是,由于该模型中自定义了不在keras.layers中的层(AttLayer),因此直接load_model时会报错:github:keras/issues/#8612,为解决该问题,可参考我的另一篇博客:
使用keras调用load_model时报错ValueError: Unknown Layer:LayerName

添加函数和程序入口

原代码中只有一个数据预处理函数clean_str和一个类AttLayer,其余部分混杂其间,导致代码结构混乱,不易理解,为此,我在重构时将各项功能以函数形式封装,并添加主程序入口和注释,大大提升了代码的可读性。此处不再赘述。

实验结果

这是在IMDB二分类数据集上进行的实验,共25000条评论,train/val/test的划分为8/1/1,epoch为10,优化函数为rmsprop。
在这里插入图片描述

总结

这次重构基本把原代码核心功能(模型相关代码、注意力层AttLayer)以外的部分改得面目全非了,添加了上述功能后,跑模型时可以输入自己想要的信息,避免在源代码上进行修改,具有更高的弹性和可读性,和原来相比好了很多。重构后的代码见my github-HAN

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

HAN论文模型代码复现与重构 的相关文章

  • Django自带的加密算法及加密模块

    Django 内置的User类提供了用户密码的存储 验证 修改等功能 xff0c 可以很方便你的给用户提供密码服务 默认的Ddjango使用pbkdf2 sha256方式来存储和管理用的密码 xff0c 当然是可以自定义的 Django 通
  • 如何在Python中使用“ with open”打开多个文件?

    我想一次更改几个文件 xff0c 前提是我可以写入所有文件 我想知道我是否可以将多个打开调用与with语句结合with xff1a try with open 39 a 39 39 w 39 as a and open 39 b 39 39
  • 工业控制领域的期刊

    我们都知道目前做控制的大体分两大类人 xff0c 一类是做纯控制理论的 xff0c 主要是跟数学打交道 xff1b 另一类是做控制理论在各个行业的应用的 xff0c 其中包括电力系统 xff0c 机器人 xff0c 智能交通 xff0c 航
  • VNC 灰屏

    用vnc连接服务器的时候 xff0c 出现了灰屏 xff0c xff08 在xshell可以正常运行 xff09 上面会显示三个checkbox xff1a Accept clipboard from viewers Send clipbo
  • Ubuntu卸载python3.6

    注意 xff1a 这里说一下 xff0c 系统自带的python3 6可别乱删 xff0c 这个是我自己下载的python3 6 若你们有想卸载系统自带的python3 6 xff0c 可千万别去卸载 xff01 一般会开机都开不起 xff
  • 深度学习之BP神经网络

    深度学习之BP神经网络 BP xff08 Back Propagation xff09 网络是1986年由Rumelhart和McCelland为首的科学家小组提出 xff0c 是一种按误差逆传播算法训练的多层前馈网络 它的学习规则是使用最
  • 【ROS】源码分析-消息订阅与发布

    说明 本文通过NodeHandle subscribe和Publication publish 源码作为入口 xff0c 来分析PubNode SubNode之间是网络连接是如何建立的 xff0c 消息是如何发布的 xff0c topic队
  • Opencv-cvtColor

    cvtColor不是cv的成员 头文件的问题 include lt opencv2 opencv hpp gt 这个就可以
  • java听课笔记——9.25

    记录今天所学的东西 xff1a 1 Random 用于随机生成一个值 xff0c 可以有限定范围 xff0c 没有尝试过不设限制的随机 用法如下 xff1a Random random 61 new Random int temp 61 r
  • java听课笔记——10.09

    1 局部变量和全局变量 xff1a 2 匿名内部类比较和外部比较 匿名内部类的比较 xff0c 即在需要进行比较的类名后加上implements comparator lt 类名 gt 然后 xff0c 使用sort xff0c 对于sor
  • java听课笔记——10.10

    1 String与常量池 xff1a 常量池是java中的一个存储常量的存储器 xff0c 栈是一个临时的存储器 xff0c 在递归的时候比较明显 xff0c 函数的运行压缩在栈里 String str3 61 new String 34
  • Java听课笔记——10.30

    感觉今天没讲什么东西唉 一开始 xff0c 解释了一下ArrayList里的每个元素如果不进行类型约束的话 自然赋值为Object类 xff0c 而且是兼收并蓄的 同时讲了使用迭代器对ArrayList数组进行遍历 xff0c 直接上代码
  • 如何在Python中声明一个数组?

    如何在Python中声明数组 xff1f 我在文档中找不到任何对数组的引用 1楼 这个怎么样 gt gt gt a 61 range 12 gt gt gt a 0 1 2 3 4 5 6 7 8 9 10 11 gt gt gt a 7
  • openrave0.9安装遇到依赖问题及解决流程

    问题 cmake 时输出下面的失败信息 xff0c 虽然最后可以make install xff08 其实就是拷贝了库文件 xff09 安装上 xff0c 但是由于过程中有些步骤失败 xff0c 导致执行时缺少一些库文件 xff0c 无法执
  • Python入门--一篇搞懂什么是类

    写一篇Python类的入门文章 xff0c 在高级编程语言中 xff0c 明白类的概念和懂得如何运用是必不可少的 文章有点长 xff0c 3000多字 Python是面向对象的高级编程语言 xff0c 在Python里面 一切都是对象 xf
  • SQL Server访问远程数据库--使用openrowset/opendatasource的方法

    一 使用openrowset opendatasource前首先要启用Ad Hoc Distributed Queries xff0c 因为这个服务不安全SqlServer默认是关闭的 SQL Server 阻止了对组件 39 Ad Hoc
  • 我的2014碎碎念—学习篇、实习篇、工作篇、生活篇

    继去年作了一次年度总结过后 xff0c 我就发誓说以后每年年末都要做一次总结 xff0c 这对自己是非常有帮助的 xff0c 无奈由于天性懒散 xff0c 2015年都过去好几天了 xff0c 才花了点心思整理下自己在过去一年里的所得所失
  • 百度2014研发类校园招聘笔试题解答

    先总体说下题型 xff0c 共有3道简答题 xff0c 3道算法编程题和1道系统设计题 xff0c 题目有难有易 xff0c 限时两小时完成 一 简答题 动态链接库和静态链接库的优缺点轮询任务调度和可抢占式调度有什么区别 xff1f 列出数
  • CSDN-markdown语法之如何插入图片

    目录 图片上传方式 插入在线图片插入本地图片图片链接方式 行内式图片链接参考式图片链接几个问题探讨 问题1 xff1a 图片上传和图片链接两种方式的区别 问题2 xff1a Markdown中如何指定图片的高和宽 xff1f 问题3 xff
  • 京东2013校园招聘软件研发笔试题

    时间 xff1a 2012 9 11 地点 xff1a 川大 我只能说第一家公司 xff0c 不是一般的火爆 不得不吐槽一下 xff1a 京东宣讲完全没有计划 xff0c 只看到个下午两点半宣讲 xff0c 结果跑过去 xff0c 下午两点

随机推荐