手把手教你:基于Django的新闻文本分类可视化系统(文本分类由bert实现)

2023-10-31

系列文章




一、项目简介

本文主要介绍如何使用python语言,基于bert的文本分类和Django的网站设计实现一个:基于Django和bert的新闻文本分类可视化系统,如果有毕业设计或者课程设计需求的同学可以参考本文。本项目同时使用了深度学习框架TensorFlow 1.X的版本,IDE为pycharm。完整代码在最下方,想要先看源码的同学可以移步本文最下方进行下载。

博主也参考过文本分类相关模型的文章,但大多是理论大于方法。很多同学肯定对原理不需要过多了解,只需要搭建出一个可视化系统即可。

也正是因为我发现网上大多的帖子只是针对原理进行介绍,功能实现的相对很少。

如果您有以上想法,那就找对地方了!


不多废话,直接进入正题!

二、任务介绍

本次任务是一个较为复杂的新闻文本分类的任务,首先需要使用bert模型对新闻文本进行分类,然后使用Django构建一个文本分类结果查询的可视化系统。

我们的任务是要构建一个模型,任意输入一篇新闻文章,可以将新闻文本分为以下几类:

label: ['体育', '财经', '房产', '家居', '教育', '科技', '时尚', '时政', '游戏', '娱乐']

三.界面简介

系统完成后界面如下:

  • 页面一:一个文本输入界面,可以将需要分类的新闻文本写入对话框。
    页面一
  • 页面二:根据输入的文本调用后台模型进行预测并分类,我这里任意找了一个娱乐新闻,可以看到文本被正确分类了。

模型分类结果

四.数据简介

本次使用的数据为标注后的文本,共计10类:['体育', '财经', '房产', '家居', '教育', '科技', '时尚', '时政', '游戏', '娱乐'],见下图:

标注的文本数据


五、代码功能介绍

1.依赖环境集IDE

本项目使用的是anaconda的jupyter notebook编译环境,如不清楚如何使用的同学可以参考csdn上其他博主的基础教程,这里就不进行赘述。

tensorflow 1.9.0以上
sklearn
pandas
python3

2.数据处理

  • 我们先通过脚本将数据集分为:train.tsvdev.tsvtest.tsvpre_test.tsv,四部分
class TextProcessor(object):
    """按照InputExample类形式载入对应的数据集"""

    """load train examples"""
    def get_train_examples(self, data_dir):
        return self._create_examples(
            self._read_file(os.path.join(data_dir, "train.tsv")), "train")

    """load dev examples"""
    def get_dev_examples(self, data_dir):
        return self._create_examples(
            self._read_file(os.path.join(data_dir, "dev.tsv")), "dev")

    """load test examples"""
    def get_test_examples(self, data_dir):
          return self._create_examples(
              self._read_file(os.path.join(data_dir, "test.tsv")), "test")

    """load pre examples"""
    def get_pre_examples(self, data_dir):
          return self._create_examples(
              self._read_file(os.path.join(data_dir, "pre_test.tsv")), "test")

    """set labels"""
    def get_labels(self):
        return ['体育', '财经', '房产', '家居', '教育', '科技', '时尚', '时政', '游戏', '娱乐']

    """read file"""
    def _read_file(self, input_file):
        with codecs.open(input_file, "r",encoding='utf-8') as f:
            lines = []
            for line in f.readlines():
                try:
                    line=line.split('\t')
                    assert len(line)==2
                    lines.append(line)
                except:
                    pass
            np.random.shuffle(lines)
            return lines

    """create examples for the data set """
    def _create_examples(self, lines, set_type):
        examples = []
        for (i, line) in enumerate(lines):
          guid = "%s-%s" % (set_type, i)
          text_a = tokenization.convert_to_unicode(line[1])
          label = tokenization.convert_to_unicode(line[0])
          examples.append(
              InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
        return examples

3.模型构建及训练

  • 然后打开控制台,运行python text_run.py train对模型进行训练。
  • 这里附上模型训练代码:
def train():
    """训练bert模型"""

    tensorboard_dir = os.path.join(config.output_dir, "tensorboard/textcnn")
    save_dir = os.path.join(config.output_dir, "checkpoints/textcnn")
    if not os.path.exists(tensorboard_dir):
        os.makedirs(tensorboard_dir)
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    save_path = os.path.join(save_dir, 'best_validation')

    start_time = time.time()

    tf.logging.info("*****************Loading training data*****************")
    train_examples = TextProcessor().get_train_examples(config.data_dir)
    trian_data = convert_examples_to_features(train_examples, label_list, config.seq_length, tokenizer)

    tf.logging.info("*****************Loading dev data*****************")
    dev_examples = TextProcessor().get_dev_examples(config.data_dir)
    dev_data = convert_examples_to_features(dev_examples, label_list, config.seq_length, tokenizer)

    tf.logging.info("Time cost: %.3f seconds...\n" % (time.time() - start_time))

    tf.logging.info("Building session and restore bert_model...\n")
    session = tf.Session()
    saver = tf.train.Saver()
    session.run(tf.global_variables_initializer())

    tf.summary.scalar("loss", model.loss)
    tf.summary.scalar("accuracy", model.acc)
    merged_summary = tf.summary.merge_all()
    writer = tf.summary.FileWriter(tensorboard_dir)
    writer.add_graph(session.graph)
    optimistic_restore(session, config.init_checkpoint)

    tf.logging.info('Training and evaluating...\n')
    best_acc = 0
    last_improved = 0  # record global_step at best_val_accuracy
    flag = False

    for epoch in range(config.num_epochs):
        batch_train = batch_iter(trian_data, config.batch_size)
        start = time.time()
        tf.logging.info('Epoch:%d' % (epoch + 1))
        for batch_ids, batch_mask, batch_segment, batch_label in batch_train:
            feed_dict = feed_data(batch_ids, batch_mask, batch_segment, batch_label, config.keep_prob)
            _, global_step, train_summaries, train_loss, train_accuracy = session.run([model.optim, model.global_step,
                                                                                       merged_summary, model.loss,
                                                                                       model.acc], feed_dict=feed_dict)
            if global_step % config.print_per_batch == 0:
                end = time.time()
                val_loss, val_accuracy = evaluate(session, dev_data)
                merged_acc = (train_accuracy + val_accuracy) / 2
                if merged_acc > best_acc:
                    saver.save(session, save_path)
                    best_acc = merged_acc
                    last_improved = global_step
                    improved_str = '*'
                else:
                    improved_str = ''
                tf.logging.info(
                    "step: {},train loss: {:.3f}, train accuracy: {:.3f}, val loss: {:.3f}, val accuracy: {:.3f},training speed: {:.3f}sec/batch {}".format(
                        global_step, train_loss, train_accuracy, val_loss, val_accuracy,
                        (end - start) / config.print_per_batch, improved_str))
                start = time.time()

            if global_step - last_improved > config.require_improvement:
                tf.logging.info("No optimization over 1500 steps, stop training")
                flag = True
                break
        if flag:
            break
        config.lr *= config.lr_decay

4.模型测试

  • 训练完成后我们使用python text_run.py test对模型进行测试。
  • 这里附上模型测试代码:
def test():
    """testing"""

    save_dir = os.path.join(config.output_dir, "checkpoints/textcnn")
    save_path = os.path.join(save_dir, 'best_validation')

    if not os.path.exists(save_dir):
        tf.logging.info("maybe you don't train")
        exit()

    tf.logging.info("*****************Loading testing data*****************")
    test_examples = TextProcessor().get_test_examples(config.data_dir)
    test_data = convert_examples_to_features(test_examples, label_list, config.seq_length, tokenizer)

    input_ids, input_mask, segment_ids = [], [], []

    for features in test_data:
        input_ids.append(features['input_ids'])
        input_mask.append(features['input_mask'])
        segment_ids.append(features['segment_ids'])

    config.is_training = False
    session = tf.Session()
    session.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    saver.restore(sess=session, save_path=save_path)

    tf.logging.info('Testing...')
    test_loss, test_accuracy = evaluate(session, test_data)
    msg = 'Test Loss: {0:>6.2}, Test Acc: {1:>7.2%}'
    tf.logging.info(msg.format(test_loss, test_accuracy))

    batch_size = config.batch_size
    data_len = len(test_data)
    num_batch = int((data_len - 1) / batch_size) + 1
    y_test_cls = [features['label_ids'] for features in test_data]
    y_pred_cls = np.zeros(shape=data_len, dtype=np.int32)

    for i in range(num_batch):
        start_id = i * batch_size
        end_id = min((i + 1) * batch_size, data_len)
        feed_dict = {
            model.input_ids: np.array(input_ids[start_id:end_id]),
            model.input_mask: np.array(input_mask[start_id:end_id]),
            model.segment_ids: np.array(segment_ids[start_id:end_id]),
            model.keep_prob: 1.0,
        }
        y_pred_cls[start_id:end_id] = session.run(model.y_pred_cls, feed_dict=feed_dict)

    '''
    输出测试矩阵
    '''
    # evaluate
    tf.logging.info("Precision, Recall and F1-Score...")
    tf.logging.info(metrics.classification_report(y_test_cls, y_pred_cls, target_names=label_list))

    tf.logging.info("Confusion Matrix...")
    cm = metrics.confusion_matrix(y_test_cls, y_pred_cls)
    tf.logging.info(cm)
  • 模型分类效果测试情况如下,这里用sklearn的混淆矩阵输出:
  • 可以看到1w篇文章,10种类别测试准确率可以达到:92%,平均loss在:0.54,这个loss不算低,因为博主时间有限所以跑的epoch不多,有兴趣的同学可以继续跑,准确率至少应该可以达到 96% 以上。
    模型测试结果

5.Django展示界面构建

由于展示界面代码较多,这里就不一一进行展示,感兴趣的同学可以在文章下方找到完整代码下载地址。

  • 这里就附上加载较为关键的后端代码。
  • 加载并初始化模型:
def get_model():
    """
    模型初始化
    """
    g_config = TextConfig()
    save_dir = os.path.join(g_config.output_dir, "checkpoints/textcnn")
    save_path = os.path.join(save_dir, 'best_validation')

    g_start_time = time.time()
    tf.logging.set_verbosity(tf.logging.INFO)

    g_label_list = TextProcessor().get_labels()
    g_tokenizer = tokenization.FullTokenizer(vocab_file=g_config.vocab_file, do_lower_case=False)
    # 初始化模型
    g_model = TextCNN(g_config)
    g_end_time = time.time()
    g_config.is_training = False
    session = tf.Session()
    session.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    saver.restore(sess=session, save_path=save_path)
    print("模型初始化时间:", g_end_time - g_start_time)
    return g_model, g_label_list, g_tokenizer, session
  • 输入文本结果预测:

def get_pre(final_model, label_list, tokenizer,session):
    """
    结果预测
    """
    config = TextConfig()
    save_dir = os.path.join(config.output_dir, "checkpoints/textcnn")
    save_path = os.path.join(save_dir, 'best_validation')

    if not os.path.exists(save_dir):
        tf.logging.info("训练路径模型不存在,请检查:‘result/checkpoints/textcnn/’,"
                        "路径下是否有保存模型:best_validation.data-00000-of-00001")
        exit()

    tf.logging.info("*****************读取预测文件*****************")
    test_examples = TextProcessor().get_pre_examples(config.data_dir)
    test_data = convert_examples_to_features(test_examples, label_list, config.seq_length, tokenizer)

    input_ids, input_mask, segment_ids = [], [], []

    for features in test_data:
        input_ids.append(features['input_ids'])
        input_mask.append(features['input_mask'])
        segment_ids.append(features['segment_ids'])

    # config.is_training = False
    # session = tf.Session()
    # session.run(tf.global_variables_initializer())
    # saver = tf.train.Saver()
    # saver.restore(sess=session, save_path=save_path)

    print('开始预测...')
    # test_loss, test_accuracy = evaluate(session, test_data)
    # msg = 'Test Loss: {0:>6.2}, Test Acc: {1:>7.2%}'
    # tf.logging.info(msg.format(test_loss, test_accuracy))

    batch_size = config.batch_size
    data_len = len(test_data)
    num_batch = int((data_len - 1) / batch_size) + 1
    y_test_cls = [features['label_ids'] for features in test_data]
    y_pred_cls = np.zeros(shape=data_len, dtype=np.int32)

    for i in range(num_batch):
        start_id = i * batch_size
        end_id = min((i + 1) * batch_size, data_len)
        feed_dict = {
            final_model.input_ids: np.array(input_ids[start_id:end_id]),
            final_model.input_mask: np.array(input_mask[start_id:end_id]),
            final_model.segment_ids: np.array(segment_ids[start_id:end_id]),
            final_model.keep_prob: 1.0,
        }
        y_pred_cls[start_id:end_id] = session.run(final_model.y_pred_cls, feed_dict=feed_dict)
    pre_label = y_pred_cls[0]
    print("预测index结果为:", pre_label)
    return pre_label

六、代码下载地址

由于项目代码量和数据集较大,感兴趣的同学可以直接下载代码,使用过程中如遇到任何问题可以在评论区进行评论,我都会一一解答。

代码下载:

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

手把手教你:基于Django的新闻文本分类可视化系统(文本分类由bert实现) 的相关文章

随机推荐

  • cocos creator入门教程实现简化版贪吃蛇

    开发工具 Cocos Creator和VS Code 开发语言 TS 简化版贪吃蛇的实现主要涉及的功能就是在吃到场景中随机产生产生的物体后 物体会到蛇头的后面并且跟随移动路径 其原理主要是通过数组来存储相关的坐标数据
  • hive-使用开窗函数实现百分比、topN、前百分比

    有一个订单表A 分别有order id 订单id user id 用户id amt 金额 三个字段 用sql实现以下功能 i 求订单总量为top3的用户及交易笔数 同时求出其交易笔数占全量订单笔数的占比 ii 求每个用户top3交易金额的订
  • SpringAOP JDK动态代理

    1 本篇博客的背景和目的 目前我本人正在学习SpringFramework的知识 这也是这个专栏的主题 我前面的几篇博文中 简单的认识了一下SpringFramework 记录了SpringFramework的环境搭建 记录了SpringI
  • 单片机 指针 的应用

    目录 直接访问物理地址下的数据 1 访问硬件指定内存下的数据 1 如设备ID号 2 将复杂格式的数据转换为字节 方便通信与存储 直接访问物理地址下的数据 1 访问硬件指定内存下的数据 1 如设备ID号 include
  • java 外部调用内部类的方法

    1 使用static可以声明一个内部类 可以直接在外部调用 class Outer 定义外部类 private static String info hello world 定义外部类的私有属性 static class Inner 使用s
  • 关于使用U盘安装ESXi发生的一些错误及解决经验

    烧录工具 rufus ESXi version 6 5U2 安装过程可以参考 https www starwindsoftware com blog create an esxi 6 5 installation usb under two
  • PyCaret入门

    安装 pip install pycaret 查看版本 from pycaret utils import version version 参考文档 GitHub 官网 用户教程 预处理 函数 模型 Notebook教程 函数 Functi
  • 解决VM Workstation安装VMware Tools显示灰色的办法

    其实虚拟机用了好多次了 但是每次使用配置时还是忘这忘那的 这里就简单地再啰嗦下了 解决办法如下 1 关闭虚拟机 2 在虚拟机设置分别设置CD DVD CD DVD2和软盘为自动检测三个步骤 3 再重启虚拟机 灰色字即点亮 如果上述步骤不行
  • 1.spark环境搭建

    Anaconda https www anaconda com products individual d JDK https docs aws amazon com zh cn corretto latest corretto 8 ug
  • OPENGL纹理加载显示颜色偏差

    问题 用Kinect Dk读出来的图像用Opencv显示没有纹理 保存为BMP也没有问题 但是OpenGL纹理加载显示出来偏蓝 解决 OpenGL纹理数据加载时使用的颜色通道错误了 原来数据的颜色通道是BGRA的 之前 glTexImage
  • mysql调优小计

    1 选择最合适的字段属性 类型 度 是否允许NULL等 尽量把字段设为not null 查询时对 是否为null 2 要尽量避免全表扫描 先应考虑在 where 及 order by 涉及的列上建 索引 3 应尽量避免在 where 句中对
  • 使用myheritage实现静态照片变成视频

    网址 https www myheritage com 首先 注册 可以使用google账号 其次 上传照片 接下来 生成动画 最后 下载视频
  • python之post上传文件简单示意

    coding utf 8 作者 萧海 联系 128 File py post py Date 9 1 2023 4 48 PM application
  • java.lang.NoClassDefFoundError: Could not initialize class com解决方案

    本文转载自 https www cnblogs com liuyangfirst p 6811937 html 作者 liuyangfirst 转载请注明该声明 编写的时候遇到这样一个bug java lang NoClassDefFoun
  • 如何使用WINDOWS7本地电脑的远程桌面连接阿里云WINDOWS服务器

    如果您的远程服务器采用了Windows服务器系统 那么使用WINDOWS7的 远程桌面连接 登录云服务器 无论在连接速度上还是方便度上 都会好很多 下面我介绍使用远程桌面连接的方法来管理云服务器 一 工具 原料 阿里云Windows SER
  • 低代码工具该如何选择?

    低代码 概念在国内持续走红 看到很多人都在问市面上这么多的低代码产品 应该如何选择 选择的标准到底是什么 这篇文章就和大家简略的分享一下三个检测 低代码 产品的标准 通过对这三方面的考量 相信大家都能擦亮眼睛找到最好的那一款 1 语言属性
  • 教你搞懂 Git!

    尽管每天你都会用到Git 常用的命令可能不到5个 但你可能现在还搞不懂它的工作原理 为什么Git可以管理版本 基本命令git add和git commit到底在干什么 在这篇文章中 我将用一个例子来解释Git的运行过程 帮助你理解Git的工
  • 派生类的定义

    类的继承与派生 基类与派生类 继承 inheritance 是面对对象程序设计的一个重要特性 是软件复用 software reuse 的一个重要形式 继承允许在原有类的基础上创建新的类 新类可以从一个或多个原有类中继承数据成员和成员函数
  • WRTnode-Windows的putty连接

    Putty是一款远程登录工具 用它可以非常方便的登录到Linux服务器上进行各种操作 命令行方式 Putty完全免费 而且无需安装 双击即可运行 支持多种连接类型 Telnet SSH Rlogin 使用简单 实在是一款十分值得推荐的工具
  • 手把手教你:基于Django的新闻文本分类可视化系统(文本分类由bert实现)

    系列文章 第十三章 手把手教你 基于python的文本分类 sklearn 决策树和随机森林实现 第十二章 手把手教你 岩石样本智能识别系统 第十一章 手把手教你 基于TensorFlow的语音识别系统 目录 系列文章 一 项目简介 二 任