人工智能学习笔记五——孪生神经网络

2023-05-16

本文将用孪生神经网络模型,对手写数字集minist进行相似度比较,用的框架是keras。如果还不清楚神经网络,可以看一下这篇文章神经网络 (caodong0225.github.io)

MNIST是一个手写体数字的图片数据集,该数据集来由美国国家标准与技术研究所发起整理,一共统计了来自250个不同的人手写数字图片,其中50%是高中生,50%来自人口普查局的工作人员。该数据集的收集目的是希望通过算法,实现对手写数字的识别。

训练集一共包含了 60,000 张图像和标签,而测试集一共包含了 10,000 张图像和标签。测试集中前5000个来自最初NIST项目的训练集.,后5000个来自最初NIST项目的测试集。前5000个比后5000个要规整,这是因为前5000个数据来自于美国人口普查局的员工,而后5000个来自于大学生。

下载地址:caodong0225.github.io/minist.zip at master · caodong0225/caodong0225.github.io

该数据集自1998年起,被广泛地应用于机器学习和深度学习领域,用来测试算法的效果,例如线性分类器(Linear Classifiers)、K-近邻算法(K-Nearest Neighbors)、支持向量机(SVMs)、神经网络(Neural Nets)、卷积神经网络(Convolutional nets)等等。

图1(minist部分手写数据集)

而Keras是一个由Python编写的开源人工神经网络库,可以作为Tensorflow、Microsoft-CNTK和Theano的高阶应用程序接口,进行深度学习模型的设计、调试、评估、应用和可视化 。

Keras在代码结构上由面向对象方法编写,完全模块化并具有可扩展性,其运行机制和说明文档有将用户体验和使用难度纳入考虑,并试图简化复杂算法的实现难度。Keras支持现代人工智能领域的主流算法,包括前馈结构和递归结构的神经网络,也可以通过封装参与构建统计学习模型。在硬件和开发环境方面,Keras支持多操作系统下的多GPU并行计算,可以根据后台设置转化为Tensorflow、Microsoft-CNTK等系统下的组件。因而本文用keras做为框架。

至于孪生神经网络(Siamese neural network),又名双生神经网络,是基于两个人工神经网络建立的耦合构架。孪生神经网络以两个样本为输入,输出其嵌入高维度空间的表征,以比较两个样本的相似程度。狭义的孪生神经网络由两个结构相同,且权重共享的神经网络拼接而成。广义的孪生神经网络,或“伪孪生神经网络(pseudo-siamese network)”,可由任意两个神经网拼接而成。孪生神经网络通常具有深度结构,可由卷积神经网络、循环神经网络等组成。

图2孪生神经网络示意图

所谓权值共享就是当神经网络有两个输入的时候,这两个输入使用的神经网络的权值是共享的(可以理解为使用了同一个神经网络)。很多时候,我们需要去评判两张图片的相似性,比如比较两张人脸的相似性,我们可以很自然的想到去提取这个图片的特征再进行比较,自然而然的,我们又可以想到利用神经网络进行特征提取。如果使用两个神经网络分别对图片进行特征提取,提取到的特征很有可能不在一个域中,此时我们可以考虑使用一个神经网络进行特征提取再进行比较。这个时候我们就可以理解孪生神经网络为什么要进行权值共享了。

孪生神经网络有两个输入(Input1 and Input2),利用神经网络将输入映射到新的空间,形成输入在新的空间中的表示。通过Loss的计算,评价两个输入的相似度。 

图3孪生神经网络示意图

映射的方法有多种,常见的映射方法为平方差映射和绝对值映射。对于输入的两张图片,在通过共享权重的神经网络的特征提取后,会得到两组一维,大小为N的特征向量W1W2。设通过映射后得到的一维向量为W3,那么平方差映射的公式为:

绝对值映射的公式为:

至此,孪生神经网络共享权重的部分就结束了,对于新的向量W3,既可以继续进行神经网络操作,也可以将向量的每个值加起来后开平方或者取平均值做为损失函数,视具体情况而定。在本例中,我们采取后者方法,我们规定这个损失函数叫做对比损失函数(contrasive loss)。公式为:

其中y表示样本标签,也就是1或者0,表示输入的两张图片是否为同一种类型的图片,如果是就为1,否则为0。margin表示阈值,因为当输入图片为不同类型时,d的值会非常大,为了防止d过大导致损失函数变化不均匀,所以设置阈值,一般margin的值取1。

观察这个式子可以发现,当输入图片为相同类型时,损失函数就是MSE损失函数开平方,神经网络会把W1W2的值调整得尽量相等,从而d的值越小,损失函数的值越小。当输入图片为不同类型时,神经网络会把W1W2的值调整得尽量不相等,从而d的值越大,如果d的值超过了阈值margin,那么损失函数的值就为0。

为了衡量模型在训练集上的准确度,对于Accuracy的计算也需要设计一个函数,在本例中,我们规定,当d>0.5时,神经网络将认定两张图片为不同图片,d<0.5时,神经网络将认定两张图片为同一图片。可视具体情况调整划分的值。

准确度的代码如下:

import keras.backend as K  

def accuracy(y_true, y_pred): # Tensor上的操作  

    return K.mean(K.equal(y_true, K.cast(y_pred < 0.5, y_true.dtype)))  

损失函数的代码如下:

import keras.backend as K  

def contrastive_loss(y_true, y_pred):  

     margin = 1  

     sqaure_pred = K.square(y_pred)  

     margin_square = K.square(K.maximum(margin - y_pred, 0))  

     return K.mean(y_true * sqaure_pred + (1 - y_true) * margin_square)

神经网络的结构构造如下:

图4孪生神经网络构造图

   完整的代码如下:

#coding:gbk  

from keras.layers import Input,Dense  

from keras.layers import Flatten,Lambda,Dropout  

from keras.models import Model  

import keras.backend as K  

from keras.models import load_model  

import numpy as np  

from PIL import Image  

import glob  

import matplotlib.pyplot as plt  

from PIL import Image  

import random  

from keras.optimizers import Adam,RMSprop  

import tensorflow as tf  

def create_base_network(input_shape):  

    image_input = Input(shape=input_shape)  

    x = Flatten()(image_input)  

    x = Dense(128, activation='relu')(x)  

    x = Dropout(0.1)(x)  

    x = Dense(128, activation='relu')(x)  

    x = Dropout(0.1)(x)  

    x = Dense(128, activation='relu')(x)  

    model = Model(image_input,x,name = 'base_network')  

    return model  

def contrastive_loss(y_true, y_pred):  

     margin = 1  

     sqaure_pred = K.square(y_pred)  

     margin_square = K.square(K.maximum(margin - y_pred, 0))  

     return K.mean(y_true * sqaure_pred + (1 - y_true) * margin_square)  

def accuracy(y_true, y_pred): # Tensor上的操作  

    return K.mean(K.equal(y_true, K.cast(y_pred < 0.5, y_true.dtype)))  

def siamese(input_shape):  

    base_network = create_base_network(input_shape)  

    input_image_1 = Input(shape=input_shape)  

    input_image_2 = Input(shape=input_shape)  

  

    encoded_image_1 = base_network(input_image_1)  

    encoded_image_2 = base_network(input_image_2)  

  

    l2_distance_layer = Lambda(  

        lambda tensors: K.sqrt(K.sum(K.square(tensors[0] - tensors[1]), axis=1, keepdims=True))  

        ,output_shape=lambda shapes:(shapes[0][0],1))  

    l2_distance = l2_distance_layer([encoded_image_1, encoded_image_2])  

      

    model = Model([input_image_1,input_image_2],l2_distance)  

      

    return model  

def process(i):  

    img = Image.open(i,"r")  

    img = img.convert("L")  

    img = img.resize((wid,hei))  

    img = np.array(img).reshape((wid,hei,1))/255  

    return img  

#model = load_model("testnumber.h5",custom_objects={'contrastive_loss':contrastive_loss,'accuracy':accuracy})  

wid=28  

hei=28  

model = siamese((wid,hei,1))  

imgset=[[],[],[],[],[],[],[],[],[],[]]  

for i in glob.glob(r"train_images\*.jpg"):  

    imgset[int(i[-5])].append(process(i))  

size = 60000  

  

r1set = []  

r2set = []  

flag = []  

for j in range(size):  

    if j%2==0:  

        index = random.randint(0,9)  

        r1 = imgset[index][random.randint(0,len(imgset[index])-1)]  

        r2 = imgset[index][random.randint(0,len(imgset[index])-1)]  

        r1set.append(r1)  

        r2set.append(r2)  

        flag.append(1.0)  

    else:  

        index1 = random.randint(0,9)  

        index2 = random.randint(0,9)  

        while index1==index2:  

            index1 = random.randint(0,9)  

            index2 = random.randint(0,9)  

        r1 = imgset[index1][random.randint(0,len(imgset[index1])-1)]  

        r2 = imgset[index2][random.randint(0,len(imgset[index2])-1)]  

        r1set.append(r1)  

        r2set.append(r2)  

        flag.append(0.0)  

r1set = np.array(r1set)  

r2set = np.array(r2set)  

flag = np.array(flag)  

model.compile(loss = contrastive_loss,  

            optimizer = RMSprop(),  

            metrics = [accuracy])  

history = model.fit([r1set,r2set],flag,batch_size=128,epochs=20,verbose=2)  

# 绘制训练 & 验证的损失值  

plt.figure()  

plt.subplot(2,2,1)  

plt.plot(history.history['accuracy'])  

plt.title('Model accuracy')  

plt.ylabel('Accuracy')  

plt.xlabel('Epoch')  

plt.legend(['Train'], loc='upper left')  

plt.subplot(2,2,2)  

plt.plot(history.history['loss'])  

plt.title('Model loss')  

plt.ylabel('Loss')  

plt.xlabel('Epoch')  

plt.legend(['Train'], loc='upper left')  

plt.show()  

model.save("testnumber.h5")

  

   训练过程图如下:

图5训练过程图

图6训练过程图

   评估的展示代码如下:

import glob  

from PIL import Image  

import random  

def process(i):  

    img = Image.open(i,"r")  

    img = img.convert("L")  

    img = img.resize((wid,hei))  

    img = np.array(img).reshape((wid,hei,1))/255  

    return img  

def contrastive_loss(y_true, y_pred):  

     margin = 1  

     sqaure_pred = K.square(y_pred)  

     margin_square = K.square(K.maximum(margin - y_pred, 0))  

     return K.mean(y_true * sqaure_pred + (1 - y_true) * margin_square)  

def accuracy(y_true, y_pred): # Tensor上的操作  

    return K.mean(K.equal(y_true, K.cast(y_pred < 0.5, y_true.dtype)))  

def compute_accuracy(y_true, y_pred):  

    pred = y_pred.ravel() < 0.5  

    return np.mean(pred == y_true)  

imgset=[]  

wid = 28  

hei = 28  

imgset=[[],[],[],[],[],[],[],[],[],[]]  

for i in glob.glob(r"test_images\*.jpg"):  

    imgset[int(i[-5])].append(process(i))  

model = load_model("testnumber.h5",custom_objects={'contrastive_loss':contrastive_loss,'accuracy':accuracy})  

for i in range(50):  

    if random.randint(0,1)==0:  

        index=random.randint(0,9)  

        r1 = random.randint(0,len(imgset[index])-1)  

        r2 = random.randint(0,len(imgset[index])-1)  

        plt.figure()  

        plt.subplot(2,2,1)  

        plt.imshow((255*imgset[index][r1]).astype('uint8'))  

        plt.subplot(2,2,2)  

        plt.imshow((255*imgset[index][r2]).astype('uint8'))  

        y_pred = model.predict([np.array([imgset[index][r1]]),np.array([imgset[index][r2]])])  

        print(y_pred)  

        plt.show()  

    else:  

        index1 = random.randint(0,9)  

        index2 = random.randint(0,9)  

        while index1==index2:  

            index1 = random.randint(0,9)  

            index2 = random.randint(0,9)  

        r1 = random.randint(0,len(imgset[index1])-1)  

        r2 = random.randint(0,len(imgset[index2])-1)  

        plt.figure()  

        plt.subplot(2,2,1)  

        plt.imshow((255*imgset[index1][r1]).astype('uint8'))  

        plt.subplot(2,2,2)  

        plt.imshow((255*imgset[index2][r2]).astype('uint8'))  

        y_pred = model.predict([np.array([imgset[index1][r1]]),np.array([imgset[index2][r2]])])  

        print(y_pred)  

        plt.show()

图7 图片相似度比较

图8 图片相似度比较

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

人工智能学习笔记五——孪生神经网络 的相关文章

  • ROS系统 创建工作空间与功能包

    ROS 学习目标 xff1a 学习内容 xff1a 使用环境 操作步骤 xff1a 基本命令 二 使用步骤创建工作空间编译工作空间创建功能包 使用C 43 43 执行程序编写源文件编辑功能包下的 Cmakelist txt文件修改目标链接库
  • 计算机网络与互联网(了解)

    文章目录 1 0 相关术语 Terms 1 1 什么是互联网1 2 互联网发展史1 3 网络体系结构1 3 1 网络边缘 Network Edge 1 3 2 网络核心 Network Core 1 3 3 接入网络与物理媒体 1 4 De
  • python库的安装、卸载和查询

    python库的安装 卸载和查询 安装库 方法1 xff1a pip install xxx 如图1所示 xff0c 在命令提示符窗口输入pip install xxx xff0c 即可在线安装指定库 xff0c 如输入pip instal
  • 计算机三级 数据库技术 前言

    考试内容及要求 1 掌握数据库技术的基本概念 原理 方法和技术 2 能够使用SQL语言实现数据库操作 3 具备数据库系统安装 配置及数据库管理与维护的基本技能 4 掌握数据库管理与维护的基本方法 5 掌握数据库性能优化的基本方法 6 了解数
  • 计算机三级 数据库技术(Chapter 2)

    第二章 xff1a 需求分析 主要内容 xff1a 需求分析的相关概念以及主要方法需求建模方法案例分析 Class 1 需求分析 1 需求分析的概念与意义 需求 xff1a 需求是指用户对软件的功能和性能的要求 就是用户的要求内容以及对要求
  • 微分几何 Class 1 向量空间

    微分几何 作为一名大三下的数学专业学生 xff0c 我本学期将实时将我所感兴趣的一门课微分几何笔记以及一些总结同步到我的博客上 xff0c 以便进行学习总结与自我督促 参考书 微分几何 苏步青 xff0c 胡和生 xff08 2016 xf
  • 微分几何 Class 2 欧氏空间

    欧氏空间 在上完上一节课之后 xff0c 我才意识到 xff0c 欧氏空间和欧氏向量空间原来不是同一个东西 但是在介绍欧氏空间之前 xff0c 我们首先来了解一下什么叫做仿射空间 Part One 仿射空间 定义 xff1a 仿射空间 设
  • 歌评-《Rex Incognito 尘世闲游》-陈致逸

    时隔一周时间了 xff0c 终于又找到了时间来更新我的歌评内容 虽然身被学校关了起来 xff0c 但是心里还是在歌曲的梦幻世界中畅游hhh 今天我们来听的歌曲也是 The Stellar Moments 闪耀的群星OST专辑中的一首 xff
  • 将Maven的Docker镜像修改为国内源

    声明 xff1a 本文CSDN作者原创投稿文章 xff0c 未经许可禁止任何形式的转载 xff0c 原文链接 前提 在使用Dockerfile构建镜像时 xff0c Maven的Docker镜像内置的是官方源 xff0c 使用起来下载速度太
  • 我看文二代

    文二代 文二代 xff0c 其实就是人们常说的文坛的后辈子女 xff0c 即父母是搞文学的作家 xff0c 子女也和文学脱不了干系 前一段 xff0c 贾平凹的女儿贾浅浅因为其浅浅体诗歌以及部分奇奇怪怪的内容上了热搜 被许多网友痛骂 对此
  • 码农多打拼5年对生子的影响

    码农多打拼五年对生子决策的影响 首先我们确定在这个问题中要处理的对象 xff1a 单个个体 他会有哪些属性呢 xff1f 1 退休年限 2 生活状态 我们要分析的是一个事件对生子数目的影响 xff0c 其实在现当代 xff0c 因为过大的工
  • 微分几何 Class 3 曲线,曲率与挠率

    正则曲线 什么是曲线 在空间中 xff0c 我们会见到各种各样的形状 xff0c 但无论什么形状 xff0c 其根本还是由点和线来构成的 xff0c 这里我们的线是一个直观的理解 xff0c 就是一条直直的 xff0c 有的也是弯的那样的
  • 随机过程 番外篇(随机拟合作业解答)

    一晚上写了三道随机过程的随机模拟的代码 xff0c 分享出来给大家做个参照 1 如果一个随机变量服从的是期望为 mu xff0c 协方差矩阵为 Sigma
  • 小云的生日史书

    小云的生日史 xff1a 生日10月21日 前三岁历史暂且不记录 xff0c 史前时期 xff0c 资料不详 四岁生日 xff1a 白天去了姥姥姥爷家去玩 xff0c 他们都对我的生日表示了祝福 下午便回到了奶奶家里 xff0c 等着生日p
  • 信息论篇-第一次上机作业,你好!

    信息论第一次上机作业 1 图像信源熵的求解 读入一幅图像 实现求解图片信源的熵 span class token triple quoted string string 1 图像信源熵的求解 读入一幅图像 实现求解图片信源的熵 span s
  • 媒体科创部 学习分享 非线性规划

    非线性规划 哇哈哈 xff0c 这次轮到我来讲了 xff0c 虽然很懒 xff0c 但是还是来写博客了 这次我们要谈的东西是非线性规划 非线性规划 非线性规划的定义 目标or限制中包含着非线性函数 线性规划与非线性规划的区别 如果线性规划的
  • 多玩家赌徒输光问题

    在随机过程课堂上我们考虑了赌徒输光问题 知道了成本和概率变化的情况将对赌徒甲和赌徒乙的赌博结果产生了怎样的影响 考虑的问题主要有以下几个方面 本金对胜负的影响 概率对胜负的影响 本金对持续轮数的影响 概率对持续轮数的影响 对上述问题的综合考
  • Spring Boot(Maven)+Docker打包

    声明 xff1a 本文CSDN作者原创投稿文章 xff0c 未经许可禁止任何形式的转载 xff0c 原文链接 本文可以实现 xff1a 将Spring Boot项目从GitHub clone到服务器上后 xff0c 一条命令直接完成依赖下载
  • 解决静态资源文件js/css缓存问题(超详细总结版)

    什么是静态资源文件 顾名思义 xff0c 静态资源文件就是js css img等非服务器动态运行生成的文件 xff0c 统称为静态 资源文件 为什么要缓存静态资源文件 静态资源文件是基本不会改变的 xff0c 没必要每次都从服务器中获取 也
  • 微分几何工具代码

    span class token keyword import span math span class token keyword from span sympy span class token keyword import span

随机推荐

  • 【璀璨数海】第一期 隐函数定理

    隐函数定理 鸽了好久了 xff0c 大三生活真的好累啊 xff01 quad quad 前两天夏令营面试的时候被问到了隐函数定理 xff0c 特此专门写一篇博文来重新复习讲解一下隐函数定理的内容 定理内容 xff1a 假定
  • hive安装与配置

    hive的安装与配置 hive介绍 xff1a Hive是基于Hadoop的一个数据仓库工具 xff0c 可以将结构化的数据文件映射为一张数据库表 xff0c 并提供类SQL查询功能 准备工作 xff1a hadoop集群成功部署卸载自带的
  • 原生spark与pyspark使用比较

    pyspark与原生spark xff08 scala xff09 比较 在学习完spark这个优秀的计算框架后 xff0c 因为当时的学习使用了python api对spark进行交互 xff0c 编写spark的原生语言为sacla x
  • Spark基础测试题

    因最近学习了scala重温spark xff0c 本篇主要是spark rdd的基础编程题 原题目地址 xff1a 题目地址 数据准备 本题所需的数据 data txt 数据结构如下依次是 xff1a 班级 姓名 年龄 性别 科目 成绩 1
  • Spark基础练习系列

    因最近学习了scala重温spark xff0c 本篇主要是spark sql与rdd的基础编程题 第一部分SparkRDD xff1a 原题目地址 xff1a 题目地址 数据准备 本题所需的数据 data txt 数据结构如下依次是 xf
  • sparkstream消费kafka序列化报错

    本篇介绍在window运行环境下 xff0c 使用spark消费kafka数据遇到的几个坑 调试环境IDEA 依赖 span class token operator lt span dependency span class token
  • Hadoop的安装和使用

    前言 xff1a 这个Hadoop的安装和使用操作起来很容易出错 xff0c 反正各种的问题 xff0c 所以在实验过程中需要细心 重复 xff0c 有的时候是机器的问题 xff0c 还有配置的问题 下面我讲一下我遇到的坑 xff01 第3
  • 树莓派 | 解决VNC Viewer无法连接显示问题

    如果觉得本篇文章对您的学习起到帮助作用 xff0c 请 点赞 43 关注 43 评论 xff0c 留下您的足迹 x1f4aa x1f4aa x1f4aa VNC Viewer是一个很不错的远程桌面应用 xff0c 但是我们在树莓派中使用时
  • kubeadm部署k8s,coredns一直处于containercreating状态failed to find plugin “flannel“ in path [/opt/cni/bin]]

    问题 xff1a coredns始终处于containercreating状态 coredns镜像拉取不下来 xff0c 只能手动拉去之后修改tag进行解决这个问题 xff0c 具体步骤如下 xff1a span class token n
  • Docker Compose部署Springboot+Mysql项目

    声明 xff1a 本文CSDN作者原创投稿文章 xff0c 未经许可禁止任何形式的转载 xff0c 原文链接 在上一篇文章Spring Boot Maven 43 Docker打包中 xff0c 我们实现了将Springboot项目源代码一
  • 华为交换机配置LACP模式链路聚合

    文章目录 1 拓扑图2 任务描述3 SwA配置4 SwB配置5 查看配置6 普通模式链路聚合演示 https blog csdn net qq 45042462 article details 120972306 1 拓扑图 2 任务描述
  • Linux 或 树莓派 4B 使用 apt 或 pip 安装 scipy

    下面的安装过程都是在树莓派4B上安装成功的 xff0c 记录一下 xff0c 仅供参考 python 3 7 使用 apt 安装 注 xff1a 这种好像python版本只能在3 8以下 xff0c 其他版本也可尝试 sudo apt ge
  • 树莓派4B 安装 sklearn

    本文记录在树莓派4B中安装sklearn库的步骤以及安装时遇到的问题 安装步骤 sudo pip3 install numpy 61 61 1 23 5 sudo apt get install python3 numpy python3
  • navicat连接mysql报错1251的解决方法

    navicat连接mysql报错1251的解决方法 1 新安装的mysql8 xff0c 使用破解版的navicat连接的时候一直报错 xff0c 如图所示 xff1a 2 网上查找原因发现是mysql8 之前的版本中加密规则是mysql
  • 一个简单的flask实例

    Flask是python编写的轻量级的web框架 span class token comment 导入Flask类 span span class token keyword from span flask span class toke
  • 基础命令(四)

    LINUX基础命令 xff08 四 xff09 一 Tail命令 1 tail使用方法 tail命令用途是依照要求将指定的文件的最后部分输出到标准设备 xff0c 通常是终端 xff0c 通俗讲来 xff0c 就是把某个档案文件的最后几行显
  • Snipaste常用快捷键(详细总结)

    Snipaste快捷键 xff08 详细总结 xff09 全局快捷键 全局操作截屏F1贴图F3退出当前截图Esc截屏并自动复制Ctrl 43 F1隐藏 显示所有贴图Shift 43 F3切换到另一组贴图Ctrl 43 F3 鼠标贴图相关操作
  • 4位数值比较器电路

    4位数值比较器电路 题目描述 xff1a 使用门级描述方式 xff0c 实现4位数值比较器 某4位数值比较器的功能如下表 96 timescale 1ns 1ns module comparator 4 input 3 0 A input
  • Maven项目pom.xml project标签爆红解决方法

    今天在打开项目的时候 xff0c 发现了一个Maven项目的问题 xff0c 在Maven项目的pom xml文件中 xff0c project标签爆出了一个错误 parent relativePath of POM com hrp spr
  • 人工智能学习笔记五——孪生神经网络

    本文将用孪生神经网络模型 xff0c 对手写数字集minist进行相似度比较 xff0c 用的框架是keras 如果还不清楚神经网络 xff0c 可以看一下这篇文章 神经网络 caodong0225 github io MNIST 是一个手