cnn手写汉字识别

2023-10-29

import os
import numpy as np
import struct
import PIL.Image
import cv2
import scipy.misc
from sklearn.utils import shuffle
import tensorflow as tf
from pylab import *

tf.app.flags.DEFINE_string("checkpoint", "ckpt/", "dir of checkpoint")
tf.app.flags.DEFINE_bool("restore", False, "restore from previous checkpoint")

FLAGS = tf.app.flags.FLAGS

# train_data_dir = "F:\HandWritingDatabases\HWDB1.1trn_gnt"
# test_data_dir = "F:\HandWritingDatabases\HWDB1.1tst_gnt"

#
train_data_dir = "../trn_gnt"
test_data_dir = "../tst_gnt"

# 取常用的100个汉字进行测试
char_set = "的一是了我不人在他有这个上们来到时大地为子中你说生国年着就那和要她出也得里后自以会家可下而过天去能对小多然于心学么之都好看起发当没成只如事把还用第样道想作种开美总从无情己面最女但现前些所同日手又行意动"
print(len(char_set))


# 从gnt文件中读取图像和对应的汉字
def read_from_gnt_dir(gnt_dir=train_data_dir):
    def one_file(f):
        header_size = 10
        while True:
            header = np.fromfile(f, dtype='uint8', count=header_size)
            if not header.size: break
            sample_size = header[0] + (header[1] << 8) + (header[2] << 16) + (header[3] << 24)
            tagcode = header[5] + (header[4] << 8)
            width = header[6] + (header[7] << 8)
            height = header[8] + (header[9] << 8)
            if header_size + width * height != sample_size:
                break
            image = np.fromfile(f, dtype='uint8', count=width * height).reshape((height, width))
            yield image, tagcode

    for file_name in os.listdir(gnt_dir):
        # print(file_name)
        if file_name.endswith('.gnt'):
            file_path = os.path.join(gnt_dir, file_name)
            # print(file_path)
            with open(file_path, 'rb') as f:
                for image, tagcode in one_file(f):
                    yield image, tagcode


 # 统计样本数和提取一点图像
def extractImge():
    # 统计样本数
    train_counter = 0
    test_counter = 0
    for image, tagcode in read_from_gnt_dir(gnt_dir=train_data_dir):
        tagcode_unicode = struct.pack('>H', tagcode).decode('gb2312')

        # 提取点图像
        if train_counter < 1000:
            im = PIL.Image.fromarray(image)
            im.convert('RGB').save('images/' + tagcode_unicode + str(train_counter) + '.png')
        else:
            break
        train_counter += 1

    for image, tagcode in read_from_gnt_dir(gnt_dir=test_data_dir):
        tagcode_unicode = struct.pack('>H', tagcode).decode('gb2312')
        test_counter += 1
    # 样本数
    print(train_counter, test_counter)


def resize_and_normalize_image(img):
    # 补方
    pad_size = abs(img.shape[0] - img.shape[1]) // 2
    if img.shape[0] < img.shape[1]:
        pad_dims = ((pad_size, pad_size), (0, 0))
    else:
        pad_dims = ((0, 0), (pad_size, pad_size))
    img = np.lib.pad(img, pad_dims, mode='constant', constant_values=255)
    # 缩放
    img = scipy.misc.imresize(img, (64 - 4 * 2, 64 - 4 * 2))
    img = np.lib.pad(img, ((4, 4), (4, 4)), mode='constant', constant_values=255)
    # assert img.shape == (64, 64)

    img = img.flatten()  # 降到一维
    # 像素值范围-1到1
    img = (img - 128) / 128
    return img


# one hot
def convert_to_one_hot(char):
    vector = np.zeros(len(char_set))
    vector[char_set.index(char)] = 1
    return vector


# 由于数据量不大, 可一次全部加载到RAM
train_data_x = []  # (m,4096)
train_data_y = []  # (m,100) [1,0,0] one-hot表示
train_data_count = 0
batch_size = 64  # 每次训练的图像数量 TODO: 改为128看看
num_batch = 0


def preProcessImg(image):

    # 裁剪图片 中心裁剪图片
    # image1 = cv2.imread(path, cv2.IMREAD_GRAYSCALE)  # 灰度化
    # 灰度化处理
    if len(image.shape) == 3 or len(image.shape) == 4 :
        image1 = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    else:
        image1 = image

    # ret,thresh = cv2.threshold(image,127,255,cv2.THRESH_BINARY) #简单阈值二值化
    # ret, thresh = cv2.threshold(image, 110, 255, cv2.THRESH_BINARY)  # 简单阈值二值化
    ret, image2 = cv2.threshold(image1, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)  # Otsu’s二值化
    # print("ret= ", ret)
    image3 = cv2.resize(image2,(64,64))

    image = 1 * (image3.flatten())
    image = np.asarray(image) / 255.0

    return image



def load_train_data():
    global train_data_x
    global train_data_y
    global num_batch
    global train_data_count
    for image, tagcode in read_from_gnt_dir(gnt_dir=train_data_dir):
        tagcode_unicode = struct.pack('>H', tagcode).decode('gb2312')
        if tagcode_unicode in char_set:
            print(tagcode_unicode)
            train_data_count += 1
            # image = preProcessImg(image)
            # train_data_x.append(image)
            train_data_x.append(resize_and_normalize_image(image))

            train_data_y.append(convert_to_one_hot(tagcode_unicode))

    # 33505
    print(np.shape(train_data_x))
    print(np.shape(train_data_y))

    # train_data_x, train_data_y = shuffle(train_data_x, train_data_y, random_state=0)
    # TODO TypeError: shuffle() takes no keyword arguments

    num_batch = len(train_data_x) // batch_size  # 向下取整
    print("num_batch=", num_batch)


#TODO 这里需要修改
def shuffleData():
    global train_data_x
    global train_data_y
    train_data_x, train_data_y = shuffle(train_data_x, train_data_y, random_state=0)


test_data_x = []  # 测试数据
test_data_y = []
test_data_count = 0

#TODO 修改,直接从提取好的图片文件夹中读取
def load_test_data():
    global test_data_x  # 测试数据
    global test_data_y
    global test_data_count
    for image, tagcode in read_from_gnt_dir(gnt_dir=test_data_dir):
        tagcode_unicode = struct.pack('>H', tagcode).decode('gb2312')
        if tagcode_unicode in char_set:
            test_data_count += 1
            # image = preProcessImg(image)
            # test_data_x.append(image)
            test_data_x.append(resize_and_normalize_image(image))

            test_data_y.append(convert_to_one_hot(tagcode_unicode))
    # shuffle样本
    # test_data_x, test_data_y = shuffle(test_data_x, test_data_y, random_state=0)
    print(np.shape(test_data_x))
    print(np.shape(test_data_y))


X = tf.placeholder(tf.float32, [None, 64 * 64])
Y = tf.placeholder(tf.float32, [None, 100])
keep_prob = tf.placeholder(tf.float32)


def chinese_hand_write_cnn():
    x = tf.reshape(X, shape=[-1, 64, 64, 1])
    # 3 conv layers
    w_c1 = tf.Variable(tf.random_normal([3, 3, 1, 32], stddev=0.01))
    b_c1 = tf.Variable(tf.zeros([32]))
    conv1 = tf.nn.relu(tf.nn.bias_add(tf.nn.conv2d(x, w_c1, strides=[1, 1, 1, 1], padding='SAME'), b_c1))
    conv1 = tf.nn.max_pool(conv1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')

    w_c2 = tf.Variable(tf.random_normal([3, 3, 32, 64], stddev=0.01))
    b_c2 = tf.Variable(tf.zeros([64]))
    conv2 = tf.nn.relu(tf.nn.bias_add(tf.nn.conv2d(conv1, w_c2, strides=[1, 1, 1, 1], padding='SAME'), b_c2))
    conv2 = tf.nn.max_pool(conv2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')

    """

    w_c3 = tf.Variable(tf.random_normal([3, 3, 64, 128], stddev=0.01))
    b_c3 = tf.Variable(tf.zeros([128]))
    conv3 = tf.nn.relu(tf.nn.bias_add(tf.nn.conv2d(conv2, w_c3, strides=[1, 1, 1, 1], padding='SAME'), b_c3))
    conv3 = tf.nn.max_pool(conv3, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
    conv3 = tf.nn.dropout(conv3, keep_prob)
    """

    # fully connect layer
    w_d = tf.Variable(tf.random_normal([16 * 16 * 64, 1024], stddev=0.01))
    b_d = tf.Variable(tf.zeros([1024]))
    dense = tf.reshape(conv2, [-1, w_d.get_shape().as_list()[0]])
    dense = tf.nn.relu(tf.add(tf.matmul(dense, w_d), b_d))
    dense = tf.nn.dropout(dense, keep_prob)

    w_out = tf.Variable(tf.random_normal([1024, 100], stddev=0.01))
    b_out = tf.Variable(tf.zeros([100]))
    # out = tf.add(tf.matmul(dense, w_out), b_out)
    out = tf.nn.softmax(tf.add(tf.matmul(dense, w_out), b_out))

    return out


lable_size = 100 #100个汉字
input_size = 64 * 64
batch_size = 64 #TODO 改成128看看?
hidden_size = 1024

# bp神经网络
def bp_nn():
    #输入层
    w1 = tf.Variable(tf.random_normal([input_size,hidden_size],stddev=0.1))
    b1 = tf.Variable(tf.constant(0.1),[hidden_size])
    #隐含层
    hidden = tf.matmul(X,w1)+b1
    hidden = tf.nn.relu(hidden)

    w2 = tf.Variable(tf.random_normal([hidden_size,lable_size],stddev=0.1))
    b2 = tf.Variable(tf.constant(0.1), [lable_size])

    #输出层
    output = tf.matmul(hidden,w2) + b2
    output = tf.nn.relu(output)
    # output = tf.nn.softmax(output)

    return  output



def train_hand_write_nn():
    output = chinese_hand_write_cnn()
    # output = bp_nn()

    loss = -tf.reduce_sum(Y * tf.log(tf.clip_by_value(output, 1e-15, 1.0)))  # loss=nan的情况,梯度爆炸?
    optimizer = tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)  # 学习率0.001 还是0.0001 TODO 改变一下学习率比较一下
    # 学习率设置为0.001时候出现loss=nan 错误

    accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(output, 1), tf.argmax(Y, 1)), tf.float32))

    # TensorBoard 可视化
    tf.summary.scalar("loss", loss)
    tf.summary.scalar("accuracy", accuracy)
    merged_summary_op = tf.summary.merge_all()

    saver = tf.train.Saver(max_to_keep=1)  # 只保存最新的模型
    max_acc = 0  # TODO: 将max_acc写到文件中
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        step = 0

        if FLAGS.restore:
            # Get last checkpoint in checkpoint directory
            checkpoint = tf.train.latest_checkpoint(FLAGS.checkpoint)
            if checkpoint:
                # Restore data from checkpoint
                saver.restore(sess, checkpoint)
                step += int(checkpoint.split('-')[-1])
                print("step=",step)
                print("Train from checkpoint")

        # 命令行执行 tensorboard --logdir=./log  打开浏览器访问http://0.0.0.0:6006
        summary_writer = tf.summary.FileWriter('./log', sess.graph)

        for e in range(50):
            for i in range(num_batch):
                batch_x = train_data_x[i * batch_size: (i + 1) * batch_size]
                batch_y = train_data_y[i * batch_size: (i + 1) * batch_size]
                _, loss_, summary = sess.run([optimizer, loss, merged_summary_op],feed_dict={X: batch_x, Y: batch_y, keep_prob: 0.5})

                # 每次迭代都保存日志
                step = e * num_batch + i
                summary_writer.add_summary(summary, step)
                print(step, "loss=", loss_)

                if (step) % 10 == 0:
                    # 计算准确率
                    # acc = accuracy.eval({X: test_data_x[:100], Y: test_data_y[:100], keep_prob: 1.})
                    acc = sess.run(accuracy, feed_dict={X: test_data_x[:100], Y: test_data_y[:100], keep_prob: 1.})
                    print(step, "accuracy=", acc)
                    if (acc > max_acc):
                        max_acc = acc
                        saver.save(sess, 'ckpt/nn-model.ckpt', global_step=step+1)



#所有测试数据的准确率,返回给遗传算法用的
def predict():
    return 0;


def test(path):
    # Read test picture and resize it, turn it to grey scale.
    # tst_image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    # tst_image = cv2.resize(tst_image, (64, 64))
    # tst_image = np.asarray(tst_image) / 255.0
    # tst_image = tst_image.reshape([-1, 64, 64, 1])
    # tst_image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    tst_image = cv2.imread(path)

    tst_image = preProcessImg(tst_image)
    print(tst_image)

    # cv2.imwrite("newphoto.png", tst_image)

    # feed the test picture into network and estimate probability distribution
    with tf.Session() as sess:
        output = chinese_hand_write_cnn()
        # output = bp_nn()
        predict = tf.nn.top_k(output, 10)
        saver = tf.train.Saver()
        saver.restore(sess=sess, save_path=tf.train.latest_checkpoint('ckpt-85/')) #FLAGS.checkpoint
        value_topk, index_topk = sess.run(predict, feed_dict={X: [tst_image], keep_prob: 0.5})


        index_topk = index_topk.flatten()
        value_topk = value_topk.flatten()
        print("value_topk:",value_topk)
        print("index_topk:",index_topk)
        for i in range(len(index_topk)):
            print("预测汉字是: ", char_set[index_topk[i]]," 概率是:",value_topk[i])



def main():
    print("main")
    # load_train_data()
    # load_test_data()
    # train_hand_write_nn()
    # test('testimages/yi.png')
    #test('testimages/shang.png')
    # test('testimages/xia.png')
    #test('testimages/wo2.png')
    test('testimages/ta.jpg')


if __name__ == '__main__':
    main()
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

cnn手写汉字识别 的相关文章

随机推荐

  • frp服务器内网穿透设置

    内网穿透的作用 内网穿透是指在一个局域网内 也称内网 中 通过某种技术手段 将局域网内部的网络资源 如ssh服务 Web服务 数据库等 暴露到公网中 从而实现公网用户对内网资源的访问和控制 它可以使得外部用户能够访问局域网内部的设备和服务
  • ubuntu20.04安装Android Studio踩坑

    1 卸载搜狗输入法 截止现在 2020年10月7日 当搜狗输入法处于活动状态时 所有jetbrains全家桶都不能用 换用百度输入法解决问题 2 Failed to install the following Android SDK pac
  • C++之内联函数

    C 之内联函数 为什么要有内联函数 我们编写了一个小函数 它的功能是比较两个string形参的长度并返回长度较小的string的引用 挑出两个string 对象中较短的那个 返回其引用 const string shorterString
  • STM32 基于keil5的printf打印设置

    1 因为使用串口来打印 所以将fputc函数和fgetc函数放在usart c源文件中 2 在usart c源文件中添加stdio h头文件 3 打印信息常用于调试 不建议使用中断 4 在usart c源文件中添加如下代码 int fput
  • 多通道振弦数据记录仪应用桥梁安全监测的解决方案

    多通道振弦数据记录仪应用桥梁安全监测的解决方案 城市化进程的加快和交通运输的发展 桥梁作为连接城市的重要交通工具 其安全性也变得越来越重要 为了保证桥梁的安全性 需要进行定期的监测和维护 其中 多通道振弦数据记录仪是一种有效的监测手段 可以
  • 零基础在家学编程,挑战年薪10万~100万

    疫情常态化 居家常态化 房贷 车贷 生活开支常态化 如何让我们的收入也常态化 有人说 我们眼下所处的二十一世纪二十年代是世界大转折之年代 而作为一位社会普通人员 如何才能跟上社会发展 如何提高自己生存能力 如何适应社会发展状态 如何保障稳定
  • 本地电脑无法登陆路由器

    以TPLINK 路由器为例 路由器有两种登录方式 1 输入如下所示字符串 http tplogin cn 适用于本地电脑ip地址是自动获得IP的情况 如果是勾选 使用下面的IP地址 则无法登陆路由器 2 输入固定IP方式 不同品牌路由器地址
  • 异常:Could not set parameters for mapping: ParameterMapping{property='xxx', mode=XX, ······}

    1 在前端页面做添加货物的数据时 将前端的数据返回到Controller的方法 执行下一步就出现以下的异常 java lang RuntimeException org mybatis spring MyBatisSystemExcepti
  • 机器学习之聚类

    无监督学习 Learning from unlabeled unannotated data without supervision 聚类概念 the process of grouping a set of objects into cl
  • h3c 交换机 密文 有解密办法吗?

    用户名123 密码123 可逆 local user 123 password cipher c 3 3 3kK6PWyha6eFuCtZ0QfnE1jVsmBOaiw 用户名123 密码123 可逆 local user 123 pass
  • 服务器物理链路,【交换机在江湖对接案例】配置堆叠系统对接NLB服务器群集示例(通过物理链路环回方法)...

    配置堆叠系统对接NLB服务器群集示例 通过物理链路环回方法 设备通过物理链路环回方法对接NLB服务器群集简介 NLB是微软在Windows Server上开发的多服务器群集负载均衡特性 交换机与NLB服务器群集相连时 NLB服务器要求交换机
  • 浅谈web前端工程师hr面试经典问题20+

    目录 前言 一 经典灵魂20问 1 你为什么不考研 2 你如何看待加班 3 为什么选择北京 4 最能概况你自己的三个词 5 你喜爱的运动 6 你的座右铭 7 谈谈你的缺点 8 对于这项工作你有那些可预见性的困难 9 如果我录用你 你将怎样开
  • 永洪科技上榜2023年度 IDC中国FinTech 50

    8月15日 全球知名的第三方研究机构IDC发布了 2023 IDC中国FinTech 50 榜单 永洪科技凭借完善的产品服务体系 差异化的产品优势以及丰富的客户实践经验 已经连续两年荣登 IDC 中国 FinTech 50 榜单 IDC作为
  • Keras Conv1d 参数及输入输出详解

    Conv1d in channels out channels kernel size stride 1 padding 0 dilation 1 groups 1 bias True filters 卷积核的数目 即输出的维度 kerne
  • C++函数模板特化,类模板特化

    一 模版与特化的概念 1 函数模版与类模版 C 中模板分为函数模板和类模板 函数模板 是一种抽象函数定义 它代表一类同构函数 类模板 是一种更高层次的抽象的类定义 2 特化的概念 所谓特化 就是将泛型的东西搞得具体化一些 从字面上来解释 就
  • 抓包神器之Charles,常用功能都在这里了

    我们在开发网站项目的时候 我们可以通过浏览器的debug模式来看request以及response的数据 那么如果我们开发移动端项目没有网页呢 如何抓取数据呢 前几天有个做服务端的师弟跟我说他不用抓包工具 遇到问题直接debug代码 那我问
  • C#开机自动启动程序代码

    新建一个winform拖一个checkbox进来 然后设置它的changed事件 已经测试过 可以直接复制使用 private void checkBox1 CheckedChanged object sender EventArgs e
  • c语言输入一串字符统计小写字母个数,c++编程实现输入一串字符,分别统计数字字符、大、小写字母、其它字符的个数...

    满意答案 keweo4016029 推荐于 2018 04 26 采纳率 40 等级 12 已帮助 6206人 include using namespace std void main int di 0 bc 0 sc 0 el 0 数字
  • 【Bug修复】解决Idea连接不上远程服务器的Redis:redis.clients.jedisJedisConnectionException: Failed to create socket

    前言 相信出现这个问题的小伙伴已经搜了很久如何解决这个问题 然而尝试了一遍又一遍后还是报出同样的错误 步骤1 修改redis conf文件 1 注释掉原先的 bind 127 0 0 1 2 将protected mode yes 修改为n
  • cnn手写汉字识别

    import os import numpy as np import struct import PIL Image import cv2 import scipy misc from sklearn utils import shuff