基于元学习孪生网络的人脸识别算法(PC复现篇)

2023-11-06

一.说明

本文参考《Python元学习 通用人工智能的实现》第二章部分内容,修改代码使其在通用环境下跑通。本文为实际项目的前期学习汇报,后续项目也许会出现在博客或者我的b站账户上(物理系的计算机选手)

原版完整代码:动手-元学习-使用-Python/2.4 使用暹罗网络的人脸识别-检查点.ipynb at master ·sudharsan13296/动手-使用 Python 进行元学习 ·GitHub

二.软件准备

1. python3.8以上版本,tensorflow2.x以上版本

2. 准备AT&T的人脸数据库

下载链接:AT&T面部数据库_图像数据_AT&T数据集-深度学习工具类资源-CSDN下载

三.创建输入对

孪生网络要求输入值成对并带有标签,所以必须以这种方式创建数据。

方法:我们从同一个文件夹中随机取出两张图片,并将其标记为正样本对;从两个文件中分别取出一张图像,并将它们标记为负样本对。具体如下图所示:

四.算法实现

1.导入库:

# 导入库
import re
import numpy as np
from PIL import Image
from sklearn.model_selection import train_test_split
from keras import backend as K
from keras.layers import Activation
from keras.layers import Input, Lambda, Dense, Dropout, Convolution2D, MaxPooling2D, Flatten
from keras.optimizers import rmsprop_v2

【注】如果下载的tf是以前的版本,最后一行调用的代码可能是RMSprop

2.定义一个函数来读取输入图像:

def read_image(filename, byteorder='>'):
    # 首先将图像以RAW格式读入缓冲区
    with open(filename, 'rb') as f:
        buffer = f.read()
    # 使用regax提取图片的头部,宽度,高度以及最大值
    header, width, height, maxval = re.search(
        b"(^P5\s(?:\s*#.*[\r\n])*"
        b"(\d+)\s(?:\s*#.*[\r\n])*"
        b"(\d+)\s(?:\s*#.*[\r\n])*"
        b"(\d+)\s(?:\s*#.*[\r\n]\s)*)", buffer).groups()
    # 然后,使用np.frombuffer,该函数用于将缓冲区转换为一维数组
    return np.frombuffer(buffer,
                         dtype='u1' if int(maxval) < 256 else
                         byteorder + 'u2',
                         count=int(width)*int(height),
                         offset=len(header)
    ).reshape((int(height), int(width)))

【注】由于不同版本所带来语句可能细微变化,用下列代码测试所编写的函数是否可以成功运行,输出结果为一个图片加上一个数组(112,92)

Image.open("face_data/s1/1.pgm").show()
img = read_image('face_data/s1/1.pgm')
print(img.shape)  

3.定义一个函数来生成数据:

def get_data(size, total_sample_size):
    # 读取图像
    image = read_image('face_data/s' + str(1) + '/' + str(1) + '.pgm', 'rw+')
    # 缩减尺寸
    image = image[::size, ::size]
    # 获取新的尺寸
    dim1 = image.shape[0]
    dim2 = image.shape[1]

    count = 0
    # 初始化数组
    x_geuine_pair = np.zeros([total_sample_size, 2, 1, dim1, dim2])
    y_genuine = np.zeros([total_sample_size, 1])

    for i in range(40):
        for j in range(int(total_sample_size/40)):
            ind1 = 0
            ind2 = 0

            # 从同一个目录中读取图像
            while ind1 == ind2:
                ind1 = np.random.randint(10)
                ind2 = np.random.randint(10)

            # 读取两个图像
            img1 = read_image('face_data/s' + str(i+1) + '/' + str(ind1 + 1) + '.pgm', 'rw+')
            img2 = read_image('face_data/s' + str(i+1) + '/' + str(ind2 + 1) + '.pgm', 'rw+')

            # 缩减尺寸
            img1 = img1[::size, ::size]
            img2 = img2[::size, ::size]

            # 将图片存入初始化的Numpy数组中
            x_geuine_pair[count, 0, 0, :, :] = img1
            x_geuine_pair[count, 1, 0, :, :] = img2

            # 分配标签是1
            y_genuine[count] = 1
            count += 1

    count = 0
    x_imposite_pair = np.zeros([total_sample_size, 2, 1, dim1, dim2])
    y_imposite = np.zeros([total_sample_size, 1])

    for i in range(int(total_sample_size/10)):
        for j in range(10):

            while True:
                ind1 = np.random.randint(40)
                ind2 = np.random.randint(40)
                if ind1 != ind2:
                    break

            img1 = read_image('face_data/s' + str(ind1+1) + '/' + str(j + 1) + '.pgm', 'rw+')
            img2 = read_image('face_data/s' + str(ind2+1) + '/' + str(j + 1) + '.pgm', 'rw+')

            img1 = img1[::size, ::size]
            img2 = img2[::size, ::size]

            x_imposite_pair[count, 0, 0, :, :] = img1
            x_imposite_pair[count, 1, 0, :, :] = img2

            # 分配标签0
            y_imposite[count] = 0
            count += 1

    X = np.concatenate([x_geuine_pair, x_imposite_pair], axis=0)/255
    Y = np.concatenate([y_genuine, y_imposite], axis=0)

    return X, Y

这里同样测试一下,测试代码如下(输出结果见备注):

X, Y = get_data(size, total_sample_size)
print(X.shape, Y.shape)
# # (20000, 2, 1, 56, 46) (20000, 1)

4.构建孪生网络:

# 构建孪生网络
def build_base_network(input_shape):
    # 容器,在这个上面进行编译
    seq = Sequential()
    nb_filter = [6, 12]
    kernel_size = 3

    # 卷积层1
    seq.add(Convolution2D(nb_filter[0], kernel_size, kernel_size, input_shape=input_shape,
                          border_mode='valid', dim_ordering='th'))
    # seq.add(Conv2D(nb_filter, kernel_size=(3, 3), ))
    # 激活函数
    seq.add(Activation('relu'))
    # 最大池化层
    seq.add(MaxPooling2D(pool_size=(2, 2)))
    seq.add(Dropout(.25))

    # 卷积层2
    seq.add(Convolution2D(nb_filter[1], kernel_size, kernel_size, border_mode='valid', dim_ordering='th'))
    # 激活函数
    seq.add(Activation('relu'))
    # 最大池化层
    seq.add(MaxPooling2D(pool_size=(2, 2), dim_ordering='th'))
    seq.add(Dropout(.25))

    # 扁平化层
    seq.add(Flatten())
    seq.add(Dense(128, activation='relu'))
    seq.add(Dropout(0.1))
    seq.add(Dense(50, activation='relu'))
    return seq

【注】convolution这个单元在新版本的keras中是被替换成conv系列的,这里建议去官网学习一下详细的参数含义,我只能说要改的还是蛮多的。

我这里将做如下修改:

 

5.接下来把图像对输入基网络,它将返回嵌入,即特征向量:
 

input_dim = x_train.shape[2:]
img_a = Input(shape=input_dim)
img_b = Input(shape=input_dim)

base_network = build_base_network(input_dim)

feat_vecs_a = base_network(img_a)
feat_vecs_b = base_network(img_b)

这里要注意的是,feat_vecs指的是图像对的特征向量,接下来把这些特征向量输入能量函数,计算它们之间的距离,并使用欧氏距离作为能量函数:

# 能量函数
def euclidean_distance(vects):
    x, y = vects
    z = K.sqrt(K.sum(K.square(x-y), axis=1, keepdims=True))
    return z


# 计算距离
def eucl_dist_output_shape(shapes):
    shape1, shape2 = shapes
    return (shape1[0],1)


distance = Lambda(euclidean_distance, output_shape=eucl_dist_output_shape([feat_vecs_a, feat_vecs_b]))

现在将轮数设置为13,并使用rmsprop进行优化,之后定义模型:

# 轮数
epochs = 3

# 优化
rms = rmsprop_v2

# 定义模型
model = Model(input=[img_a, img_b], output=distance)

6.定义损失函数:

# 定义损失函数
def contrastive_loss(y_true, y_pred):
    margin = 1
    z = K.mean(y_true * K.square(y_pred) + (1-y_true) * K.square(K.maximum(margin - y_pred, 0)))
    return z

7.训练模型:

model.compile(loss=contrastive_loss, optimizer=rms)


# 训练模型
img_1 = x_train[:, 0]
img_2 = x_train[:, 1]

model.fit([img_1, img_2], y_train, validation_split=.25, batch_size=128, verbose=2, epochs=epochs)
# model.summary()

 

 

 

 

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

基于元学习孪生网络的人脸识别算法(PC复现篇) 的相关文章

随机推荐

  • 【kali】28 提权——读取windows本地密码:pwddump、WCE、fgdump、mimikatz

    这里写自定义目录标题 一 抓包嗅探 二 键盘记录本地密码 三 查看本地缓存密码 1 浏览器查看密码 2 密码恢复工具 3 使用 Pwdump 查看 windows 本地登录密码 4 了解windows身份认证过程 5 WCE WINDOWS
  • Elasticsearch性能优化

    问题导读1 集群规划有哪些优化措施 2 磁盘该如何选择 3 内存该如何分配中 4 索引优化有哪些方法 5 数据模型优化包含哪些内容 0 题记Elasticsearch性能优化的最终目的 用户体验爽 Elasticsearch的爽点就是 快
  • Java21天打卡day19-异常

    异常 异常分类 编译时异常 程序编译时的异常例子 IO异常 SQL异常 运行时异常的区别 程序在运行时出现的异常 会自动抛出该异常 异常处理 try catch finally处理异常 throws 和 throw 的区别 throws是用
  • orcad capture学习笔记---3.DRC规则设置及检查

    1 确定版本 我用的cadence的版本是16 6 想要查看自己的orcad capture版本可以对桌面图标 右键 属性 进行查看 2 进入DRC设置界面 如下图 依次选中 dsn Tools Design Rules Check 然后会
  • 正则表达式常用的函数及用法说明

    正则表达式 正则表达式 Regular Expression 简称regex或RegExp 是一种用于描述字符串匹配规则的工具 它由一些特殊字符和普通字符组成 用于匹配符合特定模式的字符串 正则表达式可以用来实现各种功能 如搜索 匹配 替换
  • 在pycharm中导入anaconda中已安装好的库和包时出现的问题

    1 已安装好anaconda 且一些常用的包比如opencv matplotlib numpy都已安装配置完成 2 想在pycharm中直接导入anaconda里的已安装好的包 拿来使用 这样方便 因为不需要重新在pycharm里下载安装一
  • 解决MediaPlayer异常: Should have subtitle controller already set

    如果需要源码讲解或者其他问题可以私信找我 原因分析 1 MediaPlayer Should have subtitle controller already set 首先出现的这个问题在API19与在API21以上是有区别的 API 21
  • 【Java-IO】如何理解 Java 中的 IO 流?

    文章目录 1 概述 2 流的分类 1 输入流和输出流 2 字节流和字符流 3 节点流和处理流 1 概述 Java 的 IO 流是使用 Java 语言实现输入 输出的基础 可以通过调用 java io 包内的 API 很方便的实现数据的输入
  • 哈希表(限定版)

    目录 今日良言 既然没有女朋友 那就安心敲代码 一 效果展示 1 添加员工 2 显示员工 3 查找员工 4 删除员工 二 实现思路 1 总体思路分析 2 针对员工相关操作分析 三 完整代码 今日良言 既然没有女朋友 那就安心敲代码 七夕没情
  • java中String初始化的两种方式(图解)

    java中创建并初始化一个String对象 最常见的方式有两种 String str new String XXX String str XXX 前者是每一次new一个新对象 都会从堆内存中重新生成一个新的对象 后者则会在栈中创建一个对象引
  • Unity3D关于两个物体直接用圆柱进行连接画线(简单画线连接)

    最近做的东西需要用圆柱画线 网上找了些 没找到合适的 所以自己简单写了一个 这个函数只需要输入起始点和终点即可 材质可以自己调整 void DrawLS GameObject startP GameObject finalP Vector3
  • 从功能测试转型测试开发,薪资涨了20K,1000字讲述转型必经之路...

    身处职场之中 犹如逆水行舟不进则退 想要不被后浪拍死在沙滩上 就要不断学习新知识 接受新事物 要得到更好的发展 就要紧跟发展趋势 不断转型才能保持竞争力 在职场中占有一席之地 转型不是一件容易的事 涉及到转型 革新 就要突破现有的框架 必然
  • dreamweaver 正则表达式为属性值加上双引号_IT兄弟连 HTML5教程 HTML5表单 新增的表单属性3...

    9 novalidate novalidate是属性规定在提交表单时不应该验证form和input域 novalidate属性适用于的类型有 text search url telephone email password date pic
  • webService淘汰了吗?

    当代开发者们已经很少见到相关的webService开发了 那么是该技术已经被淘汰了吗 先让我们来看看其和http接口的优劣吧 这里着重说webService 该服务协议为SOAP 简单对象访问协议 说白了就是http POST的一个专用版本
  • DVWA-命令注入

    命令注入漏洞的函数 system exec passthru shell exec 与shell exec 功能相同 一 low 1 分析源码 使用的函数是shell exec 2 验证 3 漏洞测试 前面命令的输出结果作为后面命令的输入
  • 随机森林和神经网络有什么区别?

    随机森林和神经网络这两种广泛使用的机器学习算法有什么区别呢 我们什么时候应该使用神经网络 什么时候又应该使用随机森林 随机森林与神经网络哪个更好 这是一个常见问题 答案其实也非常简单 视情况而定 调皮 一起来看看何时使用随机森林好以及何时使
  • Golang大坑之循环goroutine闭包调用

    前言 回顾整个2022 突然发现我一篇博客都没写 趁着还没2022还没过去 赶紧水一篇博客 分享一下我最近学习到的一些东西 这次的主题是 Golang大坑之循环goroutine闭包调用 大家就当小故事来看吧 小美又写了bug 仔细看 这个
  • jmeter-Java关于MD5加密方法 以及16位32位互转

    MD5即Message Digest Algorithm 5 信息 摘要算法5 用于确保信息传输完整一致 是计算机广泛使用的杂凑算法之一 又译摘要算法 哈希算法 主流编程语言普遍已有MD5实现 将数据 如汉字 运算为另一固定长度值 是杂凑算
  • GDB调试进程方法

    简单易懂的gdb调试进程方法 更新中 1 首先找出需要调试的进程PID 命令 ps ef grep 进程名 2 gdb attach PID 中断进程 并附着进程 接下来就可以调试了 3 设置断点 break 函数名 文件名 行号 比如 b
  • 基于元学习孪生网络的人脸识别算法(PC复现篇)

    一 说明 本文参考 Python元学习 通用人工智能的实现 第二章部分内容 修改代码使其在通用环境下跑通 本文为实际项目的前期学习汇报 后续项目也许会出现在博客或者我的b站账户上 物理系的计算机选手 原版完整代码 动手 元学习 使用 Pyt