Tensorflow学习(五)——多任务学习验证码识别实战

2023-11-20

一、验证码生成

"""
验证码生成脚本(使用captcha包提供的ImageCaptcha方法)
"""

from captcha.image import ImageCaptcha

import sys
import random
import numpy as np

"""
使用四位数字验证码,当然也可以加入大小写字母。四位验证码有10000种可能(0000~9999)
但是由于生成过程具有随机性,难免出现重复情况,所以最终生成的验证码数量少于10000
"""
number = np.arange(0, 10)
number = [str(x) for x in number]

def random_captcha_text(char_set=number, captcha_size=4):
    # 验证码列表
    captcha_text = []
    for i in range(captcha_size):
        c = random.choice(char_set)     # 随机选中构成名称
        captcha_text.append(c)          # 加入列表
    return captcha_text

def gen_captcha_text_and_image():
    image = ImageCaptcha()
    # 获得随机生成的验证码
    captcha_text = random_captcha_text()
    # 把验证码列表转为字符串
    captcha_text = ''.join(captcha_text)
    # 生成验证码
    captcha = image.generate(captcha_text)
    image.write(captcha_text, 'captcha/images/' + captcha_text + '.jpg')


num = 10000
for i in range(num):
    gen_captcha_text_and_image()
    sys.stdout.write('\r>> Creating image %d/%d' % (i+1, num))
    sys.stdout.flush()
sys.stdout.write('\n')
sys.stdout.flush()
print('生成完毕')

验证码存放在 "./captcha/images/’ 目录下,如图:
在这里插入图片描述
验证码图片如下:在这里插入图片描述
每张图片的label就是验证码数字,此图验证码数字为0695所以文件命名为0695.jpg

二、制作tfrecord文件

1、关于tfrecord文件:

TFRecords可以允许你讲任意的数据转换为TensorFlow所支持的格式, 这种方法可以使TensorFlow的数据集更容易与网络应用架构相匹配。这种建议的方法就是使用TFRecords文件,TFRecords文件包含了[tf.train.Example 协议内存块(protocol buffer)](协议内存块包含了字段[Features],你可以写一段代码获取你的数据, 将数据填入到Example协议内存块(protocol buffer),将协议内存块序列化为一个字符串, 并且通过[tf.python_io.TFRecordWriter class]写入到TFRecords文件。

TFRecords文件格式在图像识别中有很好的使用,其可以将二进制数据和标签数据(训练的类别标签)数据存储在同一个文件中,它可以在模型进行训练之前通过预处理步骤将图像转换为TFRecords格式,此格式最大的优点实践每幅输入图像和与之关联的标签放在同一个文件中.TFRecords文件是一种二进制文件,其不对数据进行压缩,所以可以被快速加载到内存中.格式不支持随机访问,因此它适合于大量的数据流,但不适用于快速分片或其他非连续存取。

TFrecord文件读写方式参考:https://zhuanlan.zhihu.com/p/31992460

2、代码

from PIL import Image
import tensorflow as tf
import numpy as np
import os
import random
import sys

_NUM_TEST = 500
_RANDOM_SEED = 0
DATASET_DIR = 'captcha/images'
TFRECORD_DIR = 'captcha/'


# 判断tfrecord文件是否存在
def _dataset_exists(dataset_dir):
    for split_name in ['train', 'test']:
        output_filename = os.path.join(dataset_dir, split_name + '.tfrecords')
        if not tf.gfile.Exists(output_filename):
            return False
    return True


def _get_filenames_and_classes(dataset_dir):
    photo_filenames = []
    for filename in os.listdir(dataset_dir):
        # 获取文件路径
        path = dataset_dir + '/' + filename
        photo_filenames.append(path)
    return photo_filenames


def bytes_feature(values):  # 格式转换(字符串)
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))


def int64_feature(values):  # 格式转换(64位int)
    if not isinstance(values, (tuple, list)):
        values = [values]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=values))


def image_to_tfexample(image_date, label0, label1, label2, label3):
    # Abstract base class for protocol message
    return tf.train.Example(features=tf.train.Features(feature={
        'image': bytes_feature(image_date),
        'label0': int64_feature(label0),
        'label1': int64_feature(label1),
        'label2': int64_feature(label2),
        'label3': int64_feature(label3)
    }))


# 把数据转换成tfrecord格式
def _convert_dataset(split_name, filenames, dataset_dir):
    assert split_name in ['train', 'test']

    with tf.Session() as sess:
        # 定义tfrecord文件的路径和名称
        output_filename = os.path.join(TFRECORD_DIR, split_name + '.tfrecords')
        with tf.python_io.TFRecordWriter(output_filename, options=tf.python_io.TFRecordOptions(1)) as tfrecord_writer:
            for i, filename in enumerate(filenames):
                try:
                    sys.stdout.write('\r>> Converting image %d/%d' % (i + 1, len(filenames)))
                    sys.stdout.flush()
                    # 读取图片
                    image_data = Image.open(filename)
                    # 根据模型的结构resize
                    image_data = image_data.resize((224, 224))
                    # 灰度转换
                    image_data = np.array(image_data.convert('L'))
                    # 将图片转换为二进制数据
                    image_data = image_data.tobytes()
                    # 获取label
                    labels = filename.split('/')[-1][0:4]
                    num_labels = []
                    for j in range(4):
                        num_labels.append(int(labels[j]))
                    # 生成protocol数据类型
                    example = image_to_tfexample(image_data, num_labels[0], num_labels[1],
                                                 num_labels[2], num_labels[3])
                    tfrecord_writer.write(example.SerializeToString())
                except IOError as e:
                    print('Could not read:', filenames[i])
                    print('Error:', e)
                    print('Skip it\n')
    sys.stdout.write('\n')
    sys.stdout.flush()


# 判断tfrecord文件是否存在
if _dataset_exists(TFRECORD_DIR):
    print('tfrecord文件已经存在')
else:
    # 获得所有图片
    photo_filenames = _get_filenames_and_classes(DATASET_DIR)
    # 把数据集分割为训练集和测试集并打乱
    random.seed(_RANDOM_SEED)
    random.shuffle(photo_filenames)
    training_filenames = photo_filenames[_NUM_TEST:]
    testing_filenames = photo_filenames[:_NUM_TEST]

    # 数据转换
    _convert_dataset('train', training_filenames, DATASET_DIR)
    _convert_dataset('test', training_filenames, DATASET_DIR)
    print('生成tfrecord文件')

说明:DATASET_DIR定义了数据集路径,TFRECORD_DIR定义了tfrecord文件存放路径,_NUM_TEST定义了test数据集数量,该程序将所有图片分为两部分,其中获得_NUM_TEST数量的图像作为测试数据集。在_convert_dataset()中我们对图像数据进行预处理包括灰度转换、图像大小转换已经二进制转换,这些操作方便了我们将数据写入文件以及训练时候对数据的使用。

最终生成的文件如下:
在这里插入图片描述

三、验证码识别模型训练

1、验证码识别思路

将验证码label拆分为4个

例如有一个验证码为0782,则拆分后的label如下(采用one-hot编码,对应位数值置1):

Label0:1000000000
Label1:0000000100
Label2:0000000010
Label3:0010000000

好处:可使用多任务学习

2、什么是多任务学习

在这里插入图片描述
其中X是输入,Shared Layer就是一些卷积与池化操作,Task1-4对应四个标签,产生四个loss,将四个loss求和得总的loss,用优化器优化总的loss,从而降低每个标签产生的loss。

3、获取谷歌提供的alexnet_v2网络

打开github,搜索 tensorflow/models,如下:
在这里插入图片描述
将models文件夹clone下来:
在这里插入图片描述
clone完成后,在路径 “/models/research/silm/” 下找到nets文件夹,将该文件夹拷贝到项目目录,我们在训练过程中会调用nets文件夹下提供的python代码(nets_factory.py)
在这里插入图片描述

4、修改alexnet.py代码

修改后代码如下:

# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains a model definition for AlexNet.

This work was first described in:
  ImageNet Classification with Deep Convolutional Neural Networks
  Alex Krizhevsky, Ilya Sutskever and Geoffrey E. Hinton

and later refined in:
  One weird trick for parallelizing convolutional neural networks
  Alex Krizhevsky, 2014

Here we provide the implementation proposed in "One weird trick" and not
"ImageNet Classification", as per the paper, the LRN layers have been removed.

Usage:
  with slim.arg_scope(alexnet.alexnet_v2_arg_scope()):
    outputs, end_points = alexnet.alexnet_v2(inputs)

@@alexnet_v2
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from tensorflow.contrib import slim as contrib_slim

slim = contrib_slim

# pylint: disable=g-long-lambda
trunc_normal = lambda stddev: tf.compat.v1.truncated_normal_initializer(
    0.0, stddev)


def alexnet_v2_arg_scope(weight_decay=0.0005):
  with slim.arg_scope([slim.conv2d, slim.fully_connected],
                      activation_fn=tf.nn.relu,
                      biases_initializer=tf.compat.v1.constant_initializer(0.1),
                      weights_regularizer=slim.l2_regularizer(weight_decay)):
    with slim.arg_scope([slim.conv2d], padding='SAME'):
      with slim.arg_scope([slim.max_pool2d], padding='VALID') as arg_sc:
        return arg_sc


def alexnet_v2(inputs,
               num_classes=1000,
               is_training=True,
               dropout_keep_prob=0.5,
               spatial_squeeze=True,
               scope='alexnet_v2',
               global_pool=False):
  """AlexNet version 2.

  Described in: http://arxiv.org/pdf/1404.5997v2.pdf
  Parameters from:
  github.com/akrizhevsky/cuda-convnet2/blob/master/layers/
  layers-imagenet-1gpu.cfg

  Note: All the fully_connected layers have been transformed to conv2d layers.
        To use in classification mode, resize input to 224x224 or set
        global_pool=True. To use in fully convolutional mode, set
        spatial_squeeze to false.
        The LRN layers have been removed and change the initializers from
        random_normal_initializer to xavier_initializer.

  Args:
    inputs: a tensor of size [batch_size, height, width, channels].
    num_classes: the number of predicted classes. If 0 or None, the logits layer
    is omitted and the input features to the logits layer are returned instead.
    is_training: whether or not the model is being trained.
    dropout_keep_prob: the probability that activations are kept in the dropout
      layers during training.
    spatial_squeeze: whether or not should squeeze the spatial dimensions of the
      logits. Useful to remove unnecessary dimensions for classification.
    scope: Optional scope for the variables.
    global_pool: Optional boolean flag. If True, the input to the classification
      layer is avgpooled to size 1x1, for any input size. (This is not part
      of the original AlexNet.)

  Returns:
    net: the output of the logits layer (if num_classes is a non-zero integer),
      or the non-dropped-out input to the logits layer (if num_classes is 0
      or None).
    end_points: a dict of tensors with intermediate activations.
  """
  with tf.compat.v1.variable_scope(scope, 'alexnet_v2', [inputs]) as sc:
    end_points_collection = sc.original_name_scope + '_end_points'
    # Collect outputs for conv2d, fully_connected and max_pool2d.
    with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d],
                        outputs_collections=[end_points_collection]):
      net = slim.conv2d(inputs, 64, [11, 11], 4, padding='VALID',
                        scope='conv1')
      net = slim.max_pool2d(net, [3, 3], 2, scope='pool1')
      net = slim.conv2d(net, 192, [5, 5], scope='conv2')
      net = slim.max_pool2d(net, [3, 3], 2, scope='pool2')
      net = slim.conv2d(net, 384, [3, 3], scope='conv3')
      net = slim.conv2d(net, 384, [3, 3], scope='conv4')
      net = slim.conv2d(net, 256, [3, 3], scope='conv5')
      net = slim.max_pool2d(net, [3, 3], 2, scope='pool5')

      # Use conv2d instead of fully_connected layers.
      with slim.arg_scope(
          [slim.conv2d],
          weights_initializer=trunc_normal(0.005),
          biases_initializer=tf.compat.v1.constant_initializer(0.1)):
        net = slim.conv2d(net, 4096, [5, 5], padding='VALID',
                          scope='fc6')
        net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
                           scope='dropout6')
        net = slim.conv2d(net, 4096, [1, 1], scope='fc7')
        # Convert end_points_collection into a end_point dict.
        end_points = slim.utils.convert_collection_to_dict(
            end_points_collection)
        if global_pool:
          net = tf.reduce_mean(
              input_tensor=net, axis=[1, 2], keepdims=True, name='global_pool')
          end_points['global_pool'] = net
        if num_classes:
          net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
                             scope='dropout7')
          net0 = slim.conv2d(
              net,
              num_classes, [1, 1],
              activation_fn=None,
              normalizer_fn=None,
              biases_initializer=tf.compat.v1.zeros_initializer(),
              scope='fc8_0')
          net1 = slim.conv2d(
              net,
              num_classes, [1, 1],
              activation_fn=None,
              normalizer_fn=None,
              biases_initializer=tf.compat.v1.zeros_initializer(),
              scope='fc8_1')
          net2 = slim.conv2d(
              net,
              num_classes, [1, 1],
              activation_fn=None,
              normalizer_fn=None,
              biases_initializer=tf.compat.v1.zeros_initializer(),
              scope='fc8_2')
          net3 = slim.conv2d(
              net,
              num_classes, [1, 1],
              activation_fn=None,
              normalizer_fn=None,
              biases_initializer=tf.compat.v1.zeros_initializer(),
              scope='fc8_3')

          if spatial_squeeze:
            net0 = tf.squeeze(net0, [1, 2], name='fc8_0/squeezed')
          end_points[sc.name + '/fc8_0'] = net0
          if spatial_squeeze:
            net1 = tf.squeeze(net1, [1, 2], name='fc8_1/squeezed')
          end_points[sc.name + '/fc8_1'] = net1
          if spatial_squeeze:
            net2 = tf.squeeze(net2, [1, 2], name='fc8_2/squeezed')
          end_points[sc.name + '/fc8_2'] = net2
          if spatial_squeeze:
            net3 = tf.squeeze(net3, [1, 2], name='fc8_3/squeezed')
          end_points[sc.name + '/fc8_3'] = net3
      return net0, net1, net2, net3, end_points
alexnet_v2.default_image_size = 224

说明:网络中的卷积层和池化层不发生变化,原网络只有一个net输出,由于我们的验证码识别项目将验证码拆分成四个标签,所以需要四个输出,因此在源代码基础上增加net1 ~ net3输出。

5、train代码

"""验证码识别
学习模式:多任务学习
网络模型:alexnet_v2
完成时间:2020-5-1
"""

import tensorflow as tf
from nets import nets_factory


CHAR_SET_LEN = 10  # 不同字符数量
IMAGE_HEIGHT = 60  # 图片高度
IMAGE_WIDTH = 160  # 图片宽度
BATCH_SIZE = 25
TFRECORD_FILE = 'D:/PycharmProject/StudyDemo/captcha/train.tfrecords'

# placeholder
x = tf.placeholder(tf.float32, [None, 224, 224])
y0 = tf.placeholder(tf.float32, [None])
y1 = tf.placeholder(tf.float32, [None])
y2 = tf.placeholder(tf.float32, [None])
y3 = tf.placeholder(tf.float32, [None])

_learn_rate = tf.Variable(0.003, dtype=tf.float32)


# 从tfrecord文件中读出数据
def read_and_decode(filename):
    # 生成文件队列
    filename_queue = tf.train.string_input_producer([filename])
    reader = tf.TFRecordReader(options=tf.python_io.TFRecordOptions(1))
    # 返回文件名和文件
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(serialized_example, features={
        'image': tf.FixedLenFeature([], tf.string),
        'label0': tf.FixedLenFeature([], tf.int64),
        'label1': tf.FixedLenFeature([], tf.int64),
        'label2': tf.FixedLenFeature([], tf.int64),
        'label3': tf.FixedLenFeature([], tf.int64),
    })
    # 获取图片数据
    image = tf.decode_raw(features['image'], tf.uint8)
    # tf.train.shuffle_batch的使用必须确定shape
    image = tf.reshape(image, [224, 224])
    # 图片预处理
    image = tf.cast(image, tf.float32) / 255.0
    image = tf.subtract(image, 0.5)
    image = tf.multiply(image, 2.0)
    # 获取label
    label0 = tf.cast(features['label0'], tf.int32)
    label1 = tf.cast(features['label1'], tf.int32)
    label2 = tf.cast(features['label2'], tf.int32)
    label3 = tf.cast(features['label3'], tf.int32)

    return image, label0, label1, label2, label3


# 获取图片数据与标签
image, label0, label1, label2, label3 = read_and_decode(TFRECORD_FILE)
# 使用shuffle_batch随机打乱张量顺序创建批次
image_batch, label_batch0, label_batch1, label_batch2, label_batch3 = tf.train.shuffle_batch(
    [image, label0, label1, label2, label3], batch_size=BATCH_SIZE,
    capacity=50000, min_after_dequeue=10000, num_threads=1
)

# 定义网络结构
train_network_fn = nets_factory.get_network_fn('alexnet_v2',
                                               num_classes=CHAR_SET_LEN,
                                               weight_decay=0.0005,
                                               is_training=True)
with tf.Session() as sess:
    # input参数要符合Alexnet_v2网络的要求,所以先做个格式转换
    X = tf.reshape(x, [BATCH_SIZE, 224, 224, 1])
    # 数据输入网络得到输出值
    logits0, logits1, logits2, logits3, _ = train_network_fn(X)

    # 把标签转换成one_hot形式
    one_hot_labels0 = tf.one_hot(indices=tf.cast(y0, tf.int32), depth=CHAR_SET_LEN)
    one_hot_labels1 = tf.one_hot(indices=tf.cast(y1, tf.int32), depth=CHAR_SET_LEN)
    one_hot_labels2 = tf.one_hot(indices=tf.cast(y2, tf.int32), depth=CHAR_SET_LEN)
    one_hot_labels3 = tf.one_hot(indices=tf.cast(y3, tf.int32), depth=CHAR_SET_LEN)

    # 计算损失值
    loss0 = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits0,
                                                                   labels=one_hot_labels0))
    loss1 = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits1,
                                                                   labels=one_hot_labels1))
    loss2 = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits2,
                                                                   labels=one_hot_labels2))
    loss3 = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits3,
                                                                   labels=one_hot_labels3))
    # 总和损失值
    total_loss = (loss0 + loss1 + loss2 + loss3) / 4.0
    # 优化器
    optimizer = tf.train.AdamOptimizer(learning_rate=_learn_rate).minimize(total_loss)
    # 计算准确率
    correct_prediction0 = tf.equal(tf.argmax(one_hot_labels0, 1), tf.argmax(logits0, 1))
    accuracy0 = tf.reduce_mean(tf.cast(correct_prediction0, tf.float32))
    correct_prediction1 = tf.equal(tf.argmax(one_hot_labels1, 1), tf.argmax(logits1, 1))
    accuracy1 = tf.reduce_mean(tf.cast(correct_prediction1, tf.float32))
    correct_prediction2 = tf.equal(tf.argmax(one_hot_labels2, 1), tf.argmax(logits2, 1))
    accuracy2 = tf.reduce_mean(tf.cast(correct_prediction2, tf.float32))
    correct_prediction3 = tf.equal(tf.argmax(one_hot_labels3, 1), tf.argmax(logits3, 1))
    accuracy3 = tf.reduce_mean(tf.cast(correct_prediction3, tf.float32))

    # 保存模型
    saver = tf.train.Saver()
    # 初始化变量
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    # 创建一个协调器管理线程
    coord = tf.train.Coordinator()
    # 启动QueueRunner,此时文件名队列已经进队
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    for i in range(6001):
        # 获得一个批次的数据和标签
        b_image, b_label0, b_label1, b_label2, b_label3 = sess.run([image_batch,
                                                                    label_batch0,
                                                                    label_batch1,
                                                                    label_batch2,
                                                                    label_batch3])
        # 优化模型
        sess.run(optimizer, feed_dict={
            x: b_image,
            y0: b_label0,
            y1: b_label1,
            y2: b_label2,
            y3: b_label3
        })
        # 每迭代50次计算并打印一次损失值和准确率
        if i % 50 == 0:
            # 每2000次降低学习率
            if i % 2000 == 0:
                sess.run(tf.assign(_learn_rate, _learn_rate / 3))
            acc0, acc1, acc2, acc3, loss_ = sess.run([accuracy0, accuracy1, accuracy2, accuracy3, total_loss],
                                                     feed_dict={
                                                         x: b_image,
                                                         y0: b_label0,
                                                         y1: b_label1,
                                                         y2: b_label2,
                                                         y3: b_label3
                                                     })
            learing_rate = sess.run(_learn_rate)
            print('Iter: %d  loss: %.3f  accuracy:%.2f,%.2f,%.2f,%.2f  learing_rate:%.4f'
                  % (i, loss_, acc0, acc1, acc2, acc3, learing_rate))
            # 停止训练 / 保存模型
            if i == 6000:   # global_step参数是把训练次数添加到模型名称中
                saver.save(sess, './captcha/models/crack_captcha.model', global_step=i)
                break
    coord.request_stop()    # 通知其他线程关闭
    coord.join(threads)     # 其他线程关闭后该函数才可返回

代码概述:从train.tfrecord读出数据和标签,打乱,将数据送入alexnet网络得到输出值,将输出的标签转化为one_hot形式,计算loss,对loss求和得total_loss并用优化器优化,计算准确率,训练6000次,保存模型。
注意:tfrecords文件读写前后数据格式一定要对应,TFRecordWriter和TFRecordReader的options一定要相同,不然容易出现读写错误,需仔细检查。

保存的模型如下:
在这里插入图片描述
提示:训练过程较慢,笔者使用NVIDIA 940mx显卡跑满2G显存总共花费13个小时完成训练,最终准确率达到99%。

四、模型测试

代码与训练代码相似:

import tensorflow as tf
import matplotlib.pyplot as plt
from PIL import Image
from nets import nets_factory

# 不同字符数量
CHAR_SET_LEN = 10
# 图片高度和宽度
IMAGE_HEIGHT = 60
IMAGE_WIDTH = 160
# 批次
BATCH_SIZE = 1
# tfrecord文件存放路径
TFRECORD_FILE = 'captcha/test.tfrecords'

# placeholder
x = tf.placeholder(tf.float32, [None, 224, 224])


# 从tfrecord读出数据
def read_and_decode(filename):
    # 生成文件队列
    filename_queue = tf.train.string_input_producer([filename])
    reader = tf.TFRecordReader(options=tf.python_io.TFRecordOptions(1))
    # 返回文件名和文件
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(serialized_example, features={
        'image': tf.FixedLenFeature([], tf.string),
        'label0': tf.FixedLenFeature([], tf.int64),
        'label1': tf.FixedLenFeature([], tf.int64),
        'label2': tf.FixedLenFeature([], tf.int64),
        'label3': tf.FixedLenFeature([], tf.int64),
    })
    # 获取图片数据
    image = tf.decode_raw(features['image'], tf.uint8)
    # 没有经过预处理的灰度图
    image_raw = tf.reshape(image, [224, 224])
    # tf.train.shuffle_batch的使用必须确定shape
    image = tf.reshape(image, [224, 224])
    # 图片预处理
    image = tf.cast(image, tf.float32) / 255.0
    image = tf.subtract(image, 0.5)
    image = tf.multiply(image, 2.0)
    # 获取label
    label0 = tf.cast(features['label0'], tf.int32)
    label1 = tf.cast(features['label1'], tf.int32)
    label2 = tf.cast(features['label2'], tf.int32)
    label3 = tf.cast(features['label3'], tf.int32)

    return image, image_raw, label0, label1, label2, label3


# 获取图片数据与标签
image, image_raw, label0, label1, label2, label3 = read_and_decode(TFRECORD_FILE)
# 获得批次
image_batch, image_raw_batch, label_batch0, label_batch1, label_batch2, label_batch3 = tf.train.shuffle_batch(
    [image, image_raw, label0, label1, label2, label3], batch_size=BATCH_SIZE,
    capacity=50000, min_after_dequeue=10000, num_threads=1
)

# 定义网络结构
train_network_fn = nets_factory.get_network_fn('alexnet_v2',
                                               num_classes=CHAR_SET_LEN,
                                               weight_decay=0.0005,
                                               is_training=False)
with tf.Session() as sess:
    # inputs格式[batch_size, height, width, channels]
    X = tf.reshape(x, [BATCH_SIZE, 224, 224, 1])
    # 数据输入网络得到输出值
    logits0, logits1, logits2, logits3, _ = train_network_fn(X)
    # 预测值
    predict0 = tf.reshape(logits0, [-1, CHAR_SET_LEN])
    predict0 = tf.argmax(predict0, 1)

    predict1 = tf.reshape(logits1, [-1, CHAR_SET_LEN])
    predict1 = tf.argmax(predict1, 1)

    predict2 = tf.reshape(logits2, [-1, CHAR_SET_LEN])
    predict2 = tf.argmax(predict2, 1)

    predict3 = tf.reshape(logits3, [-1, CHAR_SET_LEN])
    predict3 = tf.argmax(predict3, 1)

    # 初始化
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    # 载入模型
    saver = tf.train.Saver()
    saver.restore(sess, './captcha/models/crack_captcha.model-6000')
    # 创建一个协调器管理线程
    coord = tf.train.Coordinator()
    # 启动QueueRunner,此时文件名队列已经进队
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    for i in range(10):
        # 获得一个批次的数据和标签
        b_image, b_image_raw, b_label0, b_label1, b_label2, b_label3 = sess.run([image_batch,
                                                                                 image_raw_batch,
                                                                                 label_batch0,
                                                                                 label_batch1,
                                                                                 label_batch2,
                                                                                 label_batch3])
        # 显示图片
        img = Image.fromarray(b_image_raw[0], 'L')
        plt.imshow(img)
        plt.axis('off')
        plt.show()
        # 打印标签
        print('label:', b_label0, b_label1, b_label2, b_label3)
        # 预测
        label0, label1, label2, label3 = sess.run([predict0, predict1, predict2, predict3],
                                                  feed_dict={x: b_image})
        # 打印预测值
        print('predict:', label0, label1, label2, label3)

    # 通知其他线程关闭
    coord.request_stop()
    coord.join(threads)

运行结果:
在这里插入图片描述
在这里插入图片描述

END

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Tensorflow学习(五)——多任务学习验证码识别实战 的相关文章

  • 如何更改默认的Python版本?

    我已经在我的 Mac 上安装了 Python 3 2 我跑完之后 Applications Python 3 2 Update Shell Profile command 当我输入时 这很令人困惑Python V在终端它说Python 2
  • Keras ZeroDivisionError:整数除法或以零为模

    我正在尝试使用 Keras 和 Tensorflow 实现卷积神经网络 我有以下代码 from keras models import Sequential from keras layers import Conv2D MaxPoolin
  • 分配列表的多个值

    我很想知道是否有一种 Pythonic 方式将列表中的值分配给元素 为了更清楚 我要求这样的事情 myList 3 5 7 2 a b c d something myList So that a 3 b 5 c 7 d 2 我正在寻找比手
  • 高效地将大型 Pandas 数据帧写入磁盘

    我正在尝试找到使用 Python Pandas 高效地将大型数据帧 250MB 写入磁盘或从磁盘写入的最佳方法 我已经尝试了所有方法Python 数据分析 但表现却非常令人失望 这是一个更大项目的一部分 该项目探索将我们当前的分析 数据管理
  • minAreaRect OpenCV 返回的裁剪矩形 [Python]

    minAreaRectOpenCV 中返回一个旋转的矩形 如何裁剪矩形内图像的这部分 boxPoints返回旋转矩形的角点的坐标 以便可以通过循环框内的点来访问像素 但是在 Python 中是否有更快的裁剪方法 EDIT See code在
  • Django 查询:“datetime + delta”作为表达式

    好吧 我的问题如下 假设我有下一个模型 这是一个简单的情况 class Period models Model name CharField field specs here start date DateTimeField field s
  • 如何将脚本作为 pytest 测试运行

    假设我有一个用简单脚本表示的测试assert 陈述 请参阅背景了解原因 例如 import foo assert foo 3 4 我如何以一种好的方式将该脚本包含在我的 pytest 测试套件中 我尝试了两种有效但不太好的方法 一种方法是将
  • 如何在动态执行的代码字符串中使用inspect.getsource?

    如果我在文件中有这段代码 import inspect def sample p1 print p1 return 1 print inspect getsource sample 当我运行脚本时 它按预期工作 在最后一行 源代码sampl
  • 样本()和r样本()有什么区别?

    当我从 PyTorch 中的发行版中采样时 两者sample and rsample似乎给出了类似的结果 import torch seaborn as sns x torch distributions Normal torch tens
  • Django Web 应用程序中的 SMTP 问题

    我被要求向使用 Django Python 框架实现的现有程序添加一个功能 此功能将允许用户单击一个按钮 该按钮将显示一个小对话框 表单以输入值 我确实编写了一些代码 显示电子邮件已发送的消息 但实际上 它没有发送 My code from
  • 使用 Windows 任务计划程序安排 [Virtualenv 相关] Python 脚本

    I want to schedule a python script to start at 3AM and break at 5PM every weekday However the problem arises when I need
  • django 中的身份验证方法返回 None

    你好 我在 django 中做了一个简单的注册和登录页面 当想要登录时 登录视图中的身份验证方法不返回任何内容 我的身份验证应用程序 模型 py from django db import models from django contri
  • 如何让 Streamlit 每 5 秒重新加载一次?

    我必须每 5 秒重新加载 Streamlit 图表 以便在 XLSX 报告中可视化新数据 如何实现这一目标 import streamlit as st import pandas as pd import os mainDir os pa
  • 导入目录下的所有模块

    有没有办法导入当前目录中的所有模块 并返回它们的列表 例如 对于包含以下内容的目录 mod py mod2 py mod3 py 它会给你
  • 将 Python Selenium 输出写入 Excel

    我编写了一个脚本来从在线网站上抓取产品信息 目标是将这些信息写入 Excel 文件 由于我的Python知识有限 我只知道如何在Powershell中使用Out file导出 但结果是每个产品的信息都打印在不同的行上 我希望每种产品都有一条
  • 我可以在 if 语句中使用“as”机制吗

    是否可以使用as in if类似的声明with我们使用的 例如 with open tmp foo r as ofile do something with ofile 这是我的代码 def my list rtrn lst True if
  • Python RE(总之检查第一个字母是否区分大小写,其余部分不区分大小写)

    在下面的情况下 我想匹配字符串 Singapore 其中 S 应始终为大写 其余单词可能为小写或大写 但在下面的字符串 s 是小写的 它在搜索条件中匹配 任何人都可以让我知道如何实施吗 import re st Information in
  • 旧版本的 spaCy 在尝试安装模型时抛出“KeyError: 'package'”错误

    我在 Ubuntu 14 04 4 LTS x64 上使用 spaCy 1 6 0 和 python3 5 为了安装 spaCy 的英文版本 我尝试运行 这给了我错误消息 ubun ner 3 NeuroNER master src pyt
  • 如何在supervisord中设置组?

    因此 我正在设置 Supervisord 并尝试控制多个进程 并且一切正常 现在我想设置一个组 以便我可以启动 停止不同的进程集 而不是全部或全无 这是我的配置文件的片段 group tapjoy programs tapjoy game1
  • python 日志记录替代方案 [关闭]

    Closed 此问题正在寻求书籍 工具 软件库等的推荐 不满足堆栈溢出指南 help closed questions 目前不接受答案 蟒蛇记录模块 http docs python org library logging html使用起来

随机推荐

  • python——selenium

    一 Selenium Python环境搭建及配置 1 1 selenium 介绍 selenium 是一个 web 的自动化测试工具 不少学习功能自动化的同学开始首选 selenium 因为它相比 QTP 有诸多有点 免费 也不用再为破解
  • cpolar内网穿透+ EasyImage组合,自建一个图床网站

    文章目录 1 前言 2 EasyImage网站搭建 2 1 EasyImage下载和安装 2 2 EasyImage网页测试 2 3 cpolar的安装和注册 3 本地网页发布 3 1 Cpolar云端设置 3 2 Cpolar内网穿透本地
  • 【马士兵】Python基础--15

    Python基础 15 文章目录 Python基础 15 编程思想 类与对象 类的创建 对象的创建 类属性 类方法 静态方法 动态绑定属性和方法 知识点总结 编程思想 类与对象 python中一切皆对象 类的创建 类的名称由一个或多个单词组
  • 【SpringCloud】SpringAMQP总结

    文章目录 1 AMQP 2 基本消息模型队列 3 WorkQueue模型 4 发布订阅模型 5 发布订阅 Fanout Exchange 6 发布订阅 DirectExchange 7 发布订阅 TopicExchange 8 消息转换器
  • 迁移学习 & 凯明初始化

    前言 这一章其实就是之前没做完的事 来补一下 两者其实没啥关系 迁移学习 以下内容学习自迁移学习 斯坦福21秋季 实用机器学习中文版 迁移学习包括什么 feature extraction train a model on a relate
  • 由于缺少调试目标 E:a\b\c\串口配置工具\bin\Debug\串口配置工具.exe“,visual Studio无法开始调试。请生成项目并重试,或者相应OutputPath和AssemblyNa

    最近做一个窗体程序时候出现这个错误 我的项目名称是串口配置工具 建议为英文来命名 项目名称下面有这两个 发现 没有这个串口配置工具 exe 然后再这个 这里面发现这个串口配置工具 exe 最后直接 exe文件把这个复制到 项目名称 bin
  • C++基础——const成员函数

    目录 一 Const成员函数 1 定义 2 格式 3 代码示例 h文件 definition cpp文件 特性 例 那么const对象既可以调用非const型成员函数吗 问题3 const成员函数内可以调用其它的非const成员函数吗 问题
  • 手机运行python 神器,pydroid3 包含库的版本

    初次安装pydroid 或者qpython的同学运行爬虫时是不是蛋疼的一比 lxml根本装不了 虽然可以下载whl折腾 可是也很麻烦 后来我不死心 终于找到了包含库的版本 只有pydroid 64位 https lanzous com id
  • msa2000映射到服务器,HPmsa2000i官方详细的设置操作流程步骤.doc

    HPmsa2000i官方详细的设置操作流程步骤 从本地管理主机登录进入 SMU 如要从本地管理主机登录进入 SMU 在网络浏览器的地址栏中 键入某个控制器机柜的以太网管理端口的 IP 地址 然后按Enter 此时显示 SMU Login 页
  • IDEA java.lang.NullPointerException (no error message)

    今天在不停启动debug 停止debug后无法再启动debug 提示java lang NullPointerException no error message 经百度 删除 project下 gradle无效 恢复代码后无效 且未更改配
  • 【C语言】合并两个数组,降序排列并删除重复元素(通俗易懂)

    问题描述 试着写一个程序 具体内容如下 建立两个整型数组 int n scanf d n int a n 将其合并 对他们进行降序排序 去掉相同项 输出处理过后的数组 输入形式 首先第一行输入第一个数组中的长度n 然后输入n个整型数 然后在
  • MYSQL进阶-msql日志-慢查询日志

    2 慢查询日志 慢查询日志主要用来记录执行时间超过设置的某个时长的SQL语句 能够帮助数据库维护人员找出执行时间比较长 执行效率比较低的SQL语句 并对这些SQL语句进行针对性优化 2 1 开启慢查询日志 可以在my cnf文件或者my i
  • ant design pro 代码学习(七) ----- 组件封装(登录模块)

    以登录模块为例 对ant design pro的组件封装进行相关分析 登录模块包含基础组件的封装 组件按模块划分 同类组件通过配置文件生成 跨层级组件直接数据通信等 相对来说还是具有一定的代表性 1 登录模块流程图 首先 全局了解一下登录模
  • 在idea中安装并且使用easy code插件 ,以及在idea中配置mysql数据库

    在idea中安装并且使用easy code插件 以及在idea中配置mysql数据库 1 从导航栏进入设置页面 2 点击plugins选项 在输入框中输入easy code查找 并点击installed安装 下载安装好了以后需要重启软件 点
  • GNURadio报错Unable to create context(windows10环境)

    GNURadio报错Unable to create context windows10环境 这里本人使用的是GNU Radio3 7 11 iiosupport win64 版本 外设是ADI的ADALM PLUTO 这里本人使用的是GN
  • 多维时序

    多维时序 MATLAB实现ELM极限学习机多维时序预测 股票价格预测 目录 多维时序 MATLAB实现ELM极限学习机多维时序预测 股票价格预测 效果一览 基本介绍 程序设计 结果输出 参考资料 效果一览 基本介绍
  • 2018-12-13 LeetCode Q5 最长回文子串

    5 最长回文子串 给定一个字符串 s 找到 s 中最长的回文子串 你可以假设 s 的最大长度为 1000 示例 1 输入 babad 输出 bab 注意 aba 也是一个有效答案 示例 2 输入 cbbd 输出 bb 暴力解法 6004ms
  • 关于linux进程间的close-on-exec机制

    转载请注明出处 帘卷西风的专栏 http blog csdn net ljxfblog 前几天写了一篇博客 讲述了端口占用情况的查看和解决 关于linux系统端口查看和占用的解决方案 大部分这种问题都能够解决 在文章的最后 提到了一种特殊情
  • 判断字符串的两半是否相似(1704.leetcode)-------------------c++实现

    判断字符串的两半是否相似 1704 leetcode unordered map c 实现 题目表述 给你一个偶数长度的字符串 s 将其拆分成长度相同的两半 前一半为 a 后一半为 b 两个字符串 相似 的前提是它们都含有相同数目的元音 a
  • Tensorflow学习(五)——多任务学习验证码识别实战

    一 验证码生成 验证码生成脚本 使用captcha包提供的ImageCaptcha方法 from captcha image import ImageCaptcha import sys import random import numpy