pytorch——实现自编码器和变分自编码器

2023-11-06

本文只设计简单的介绍和pytorch实现,数学原理部分请移步知乎“不刷知乎”大佬的文章:半小时理解变分自编码器
本文部分观点引自上述知乎文章。

数据降维

降维是减少描述数据的特征数量的过程。可以通过选择(仅保留一些现有特征)或通过提取(基于旧特征组合来生成数量更少的新特征)来进行降维。降维在许多需要低维数据(数据可视化,数据存储,繁重的计算…)的场景中很有用。

主成分分析(PCA)

PCA的想法是构建m个新的独立特征,这些特征是n个旧特征的线性组合,并使得这些新特征所定义的子空间上的数据投影尽可能接近初始数据(就欧几里得距离而言)。换句话说,PCA寻找初始空间的最佳线性子空间(由新特征的正交基定义),以使投影到该子空间上的近似数据的误差尽可能小。

自编码器(AE)

简单来说就是使用神经网络做编码器(Encoder)和解码器(Decoder)。

  • 输入和输出的维度是一致的,保证能够重建
  • 中间有一个neck,可以升维或者降维(常用与降维)

在这里插入图片描述
自编码器的缺点:

  • 缺乏规则性:隐空间中缺乏可解释和可利用的结构
  • 自编码器的高自由度使得可以在没有信息损失的情况下进行编码和解码(尽管隐空间的维数较低)但会导致严重的过拟合,这意味着隐空间的某些点将在解码时给出无意义的内容。

变分自编码器(VAE)

隐空间的规则性可以通过两个主要属性表示:

  • 连续性(continuity,隐空间中的两个相邻点解码后不应呈现两个完全不同的内容);
  • 完整性(completeness,针对给定的分布,从隐空间采样的点在解码后应提供“有意义”的内容)。

简单来说,为了保证隐空间的规则性,VAE的编码器不是将输入编码为隐空间中的单个点,而是将其编码为隐空间中的概率分布。然后解码时按照此概率分布从隐空间中采样进行解码。训练过程:

  • 首先,将输入编码为在隐空间上的分布;
  • 第二,从该分布中采样隐空间中的一个点;
  • 第三,对采样点进行解码并计算出重建误差;
  • 最后,重建误差通过网络反向传播。

在这里插入图片描述

在这里插入图片描述如何计算KL散度:
在这里插入图片描述
将输入编码为具有一定方差而不是单个点的分布的原因是这样可以非常自然地表达隐空间规则化:编码器返回的分布被强制接近标准正态分布。

pytorch实现

本节实现AE和VAE对MNIST数据集的编码与解码(重现)。

AE

实现自编码器网络结构

'''
定义自编码器网络结构
'''

import torch
from torch import nn


class AE(nn.Module):
    def __init__(self):
        super(AE, self).__init__()

        # [b, 784] => [b, 20]
        self.encoder = nn.Sequential(
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, 20),
            nn.ReLU()
        )

        # [b, 20] => [b, 784]
        self.decoder = nn.Sequential(
            nn.Linear(20, 64),
            nn.ReLU(),
            nn.Linear(64, 256),
            nn.ReLU(),
            nn.Linear(256, 784),
            nn.Sigmoid()
        )


    def forward(self, x):
        """
        :param x: [b, 1, 28, 28]
        :return:
        """
        batchsz = x.size(0)
        # flatten(打平)
        x = x.view(batchsz, 784)
        # encoder
        x = self.encoder(x)
        # decoder
        x = self.decoder(x)
        # reshape
        x = x.view(batchsz, 1, 28, 28)

        return x

实现AE对MNIST数据集的处理

'''
此处需要安装并开启visdom
安装:pip install visdom
开启:python -m visdom.server
'''

import torch
from torch.utils.data import DataLoader
from torch import nn, optim
from torchvision import transforms, datasets

from ae import AE

import visdom


def main():
    '''import mnist dataset'''
    mnist_train = datasets.MNIST('mnist', True, transform=transforms.Compose([
        transforms.ToTensor()
    ]), download=True)
    mnist_train = DataLoader(mnist_train, batch_size=32, shuffle=True)

    mnist_test = datasets.MNIST('mnist', False, transform=transforms.Compose([
        transforms.ToTensor()
    ]), download=True)
    mnist_test = DataLoader(mnist_test, batch_size=32, shuffle=True)

    '''show shape of data'''
    x, label_unuse = iter(mnist_train).next()
    print('x:', x.shape)  # torch.Size([32, 1, 28, 28])

    '''定义神经网络相关内容'''
    device = torch.device('cuda')
    model = AE().to(device)
    # model = VAE().to(device)
    criteon = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    print(model)

    '''可视化'''
    viz = visdom.Visdom()

    for epoch in range(1000):

        '''train'''
        for batchidx, (x, label_unuse) in enumerate(mnist_train):
            # [b, 1, 28, 28]
            x = x.to(device)

            x_hat = model(x)
            loss = criteon(x_hat, x)

            # backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(epoch, 'loss:', loss.item())

        '''test'''
        x, label_unuse = iter(mnist_test).next()
        x = x.to(device)
        with torch.no_grad():
            x_hat = model(x)

        '''show test result'''
        viz.images(x, nrow=8, win='x', opts=dict(title='x'))
        viz.images(x_hat, nrow=8, win='x_hat', opts=dict(title='x_hat'))


if __name__ == '__main__':
    main()

VAE

实现变分自编码器网络结构

'''
定义变分自编码器网络结构
'''

import torch
from torch import nn


class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()

        # [b, 784] => [b, 20]
        # u: [b, 10]
        # sigma: [b, 10]
        self.encoder = nn.Sequential(
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, 20),
            nn.ReLU()
        )

        # [b, 10] => [b, 784]
        self.decoder = nn.Sequential(
            nn.Linear(10, 64),
            nn.ReLU(),
            nn.Linear(64, 256),
            nn.ReLU(),
            nn.Linear(256, 784),
            nn.Sigmoid()
        )

        self.criteon = nn.MSELoss()

    def forward(self, x):
        """

        :param x: [b, 1, 28, 28]
        :return:
        """
        batchsz = x.size(0)
        # flatten
        x = x.view(batchsz, 784)

        # encoder
        # [b, 20], including mean and sigma
        h_ = self.encoder(x)

        # [b, 20] => [b, 10] and [b, 10]
        mu, sigma = h_.chunk(2, dim=1)
        # reparametrize trick, epison~N(0, 1)
        h = mu + sigma * torch.randn_like(sigma)  # 随机抽样

        # decoder
        x_hat = self.decoder(h)

        # reshape
        x_hat = x_hat.view(batchsz, 1, 28, 28)

        kld = 0.5 * torch.sum(
            torch.pow(mu, 2) +
            torch.pow(sigma, 2) -
            torch.log(1e-8 + torch.pow(sigma, 2)) - 1
        ) / (batchsz*28*28)  # 计算与标准正态分布相比的散度

        return x_hat, kld

实现VAE对MNIST数据集的处理

'''
此处需要安装并开启visdom
安装:pip install visdom
开启:python -m visdom.server
'''

import torch
from torch.utils.data import DataLoader
from torch import nn, optim
from torchvision import transforms, datasets

from vae import VAE

import visdom


def main():
    '''import data set'''
    mnist_train = datasets.MNIST('mnist', True, transform=transforms.Compose([
        transforms.ToTensor()
    ]), download=True)
    mnist_train = DataLoader(mnist_train, batch_size=32, shuffle=True)

    mnist_test = datasets.MNIST('mnist', False, transform=transforms.Compose([
        transforms.ToTensor()
    ]), download=True)
    mnist_test = DataLoader(mnist_test, batch_size=32, shuffle=True)

    '''show data shape'''
    x, _ = iter(mnist_train).next()
    print('x:', x.shape)  # torch.Size([32, 1, 28, 28])

    '''def model'''
    device = torch.device('cuda')
    model = VAE().to(device)
    criteon = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    print(model)

    viz = visdom.Visdom()

    for epoch in range(1000):
        '''train'''
        for batchidx, (x, _) in enumerate(mnist_train):
            # [b, 1, 28, 28]
            x = x.to(device)

            x_hat, kld = model(x)
            loss = criteon(x_hat, x)

            if kld is not None:
                elbo = - loss - 1.0 * kld
                loss = - elbo

            # backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(epoch, 'loss:', loss.item(), 'kld:', kld.item())

        '''test'''
        x, _ = iter(mnist_test).next()
        x = x.to(device)
        with torch.no_grad():
            x_hat, kld = model(x)

        '''show test result'''
        viz.images(x, nrow=8, win='x', opts=dict(title='x'))
        viz.images(x_hat, nrow=8, win='x_hat', opts=dict(title='x_hat'))


if __name__ == '__main__':
    main()
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

pytorch——实现自编码器和变分自编码器 的相关文章

  • 大模型算法工程师的面试题来了(附答案)

    自 ChatGPT 在去年 11 月底横空出世 大模型的风刮了整一年 历经了百模大战 Llama 2 开源 GPTs 发布等一系列里程碑事件 将大模型技术推至无可争议的 C 位 基于大模型的研究与讨论 也让我们愈发接近这波技术浪潮的核心 最
  • 10个 Python 脚本来自动化你的日常任务

    在这个自动化时代 我们有很多重复无聊的工作要做 想想这些你不再需要一次又一次地做的无聊的事情 让它自动化 让你的生活更轻松 那么在本文中 我将向您介绍 10 个 Python 自动化脚本 以使你的工作更加自动化 生活更加轻松 因此 没有更多
  • 杂项:机器学习平台

    概述 机器学习学科正在快速扩展 因此 选择合适的机器学习平台至关重要 这有助于利用端到端方法成功构建模型 机器学习平台为用户提供了创建 实施和增强机器学习 尤其是机器学习算法 的工具 介绍 随着组织收集更多数据 使用机器学习和其他人工智能
  • 人工智能深度学习:探索智能的深邃奥秘

    导言 人工智能深度学习作为当今科技领域的明星 正引领着智能时代的浪潮 深度学习和机器学习作为人工智能领域的两大支柱 它们之间的关系既有协同合作 又存在着显著的区别 本文将深入研究深度学习在人工智能领域的角色 以及其在各行各业中的深远影响 研
  • 软件测试/测试开发/人工智能丨机器学习中特征的含义,什么是离散特征,什么是连续特征。

    在机器学习中 特征 Feature 是输入数据中的属性或变量 用于描述样本或数据点 特征对于机器学习模型而言是输入的一部分 模型通过学习样本的特征与其对应的标签 或输出 之间的关系来做出预测或分类 特征可以分为不同类型 其中两个主要的类型是
  • 软件测试/测试开发/人工智能丨机器学习中特征的含义,什么是离散特征,什么是连续特征。

    在机器学习中 特征 Feature 是输入数据中的属性或变量 用于描述样本或数据点 特征对于机器学习模型而言是输入的一部分 模型通过学习样本的特征与其对应的标签 或输出 之间的关系来做出预测或分类 特征可以分为不同类型 其中两个主要的类型是
  • 【YOLO算法训练数据集处理】缩放训练图片的大小,同时对图片的标签txt文件中目标的坐标进行同等的转换

    背景 在训练一个自己的yolo模型目标检测模型时 使用公共数据集时 通常要将图片缩放处理 而此时图片对应的标签文件中目标的坐标也应进行同等的变换 这样才能保证模型的正确训练 当然 如果是自建的数据集 则将图片进行缩放后 使用Labelimg
  • .h5文件简介

    一 简介 HDF5 Hierarchical Data Format version 5 是一种用于存储和组织大量数据的文件格式 它支持高效地存储和处理大规模科学数据的能力 HDF5 是一种灵活的数据模型 可以存储多种数据类型 包括数值数据
  • 基于生成式对抗网络的视频生成技术

    随着人工智能的快速发展 生成式对抗网络 GAN 作为一种强大的生成模型 已经在多个领域展现出了惊人的能力 其中 基于GAN的视频生成技术更是引起了广泛的关注 本文将介绍基于生成式对抗网络的视频生成技术的原理和应用 探索其对电影 游戏等领域带
  • 第二部分相移干涉术

    典型干涉图 相移干涉术 相移干涉术的优点 1 测量精度高 gt 1 1000 条纹 边缘跟踪仅为 1 10 边缘 2 快速测量 3 低对比度条纹测量结果良好 4 测量结果不受瞳孔间强度变化的影响 独立于整个瞳孔的强度变化 5 在固定网格点获
  • 机器学习---决策树

    介绍 决策树和随机森林都是非线性有监督的分类模型 决策树是一种树形结构 树内部每个节点表示一个属性上的测试 每个分支代表一个测试输出 每个叶子节点代表一个分类类别 通过训练数据构建决策树 可以对未知数据进行分类 随机森林是由多个决策树组成
  • 基于ResNet模型微调的自定义图像数据分类

    Import necessary packages import torch import torch nn as nn from torchvision import datasets models transforms from tor
  • 互操作性(Interoperability)如何影响着机器学习的发展?

    互操作性 Interoperability 也称为互用性 即两个系统之间有效沟通的能力 是机器学习未来发展中的关键因素 对于银行业 医疗和其他生活服务行业 我们期望那些用于信息交换的平台可以在我们需要时无缝沟通 我们每个人都有成千上万个数据
  • lr推荐模型 特征重要性分析

    在分析lr模型特征重要性之前 需要先明白lr模型是怎么回事儿 lr模型公式是sigmoid w1 x1 w2 x2 wn xn 其中w1 w2 wn就是模型参数 x1 x2 xn是输入的特征值 对于lr模型来说 特征可以分为两个粒度 一个是
  • 蒙牛×每日互动合作获评中国信通院2023“数据+”行业应用优秀案例

    当前在数字营销领域 品牌广告主越来越追求品效协同 针对品牌主更注重营销转化的切实需求 数据智能上市企业每日互动 股票代码 300766 发挥自身数据和技术能力优势 为垂直行业的品牌客户提供专业的数字化营销解决方案 颇受行业认可 就在不久前举
  • 如何用GPT制作PPT和写代码?

    详情点击链接 如何用GPT制作PPT和写模型代码 一OpenAI 1 最新大模型GPT 4 Turbo 2 最新发布的高级数据分析 AI画图 图像识别 文档API 3 GPT Store 4 从0到1创建自己的GPT应用 5 模型Gemin
  • 基于GPT4+Python近红外光谱数据分析及机器学习与深度学习建模

    详情点击链接 基于ChatGPT4 Python近红外光谱数据分析及机器学习与深度学习建模教程 第一 GPT4 基础 1 ChatGPT概述 GPT 1 GPT 2 GPT 3 GPT 3 5 GPT 4模型的演变 2 ChatGPT对话初
  • 【毕业设计选题】复杂背景下的无人机(UVA)夜间目标检测系统 python 人工智能 深度学习

    前言 大四是整个大学期间最忙碌的时光 一边要忙着备考或实习为毕业后面临的就业升学做准备 一边要为毕业设计耗费大量精力 近几年各个学校要求的毕设项目越来越难 有不少课题是研究生级别难度的 对本科同学来说是充满挑战 为帮助大家顺利通过和节省时间
  • 机器学习算法实战案例:时间序列数据最全的预处理方法总结

    文章目录 1 缺失值处理 1 1 统计缺失值 1 2 删除缺失值 1 3 指定值填充 1 4 均值 中位数 众数填充
  • 如何用GPT进行论文润色与改写?

    详情点击链接 如何用GPT GPT4进行论文润色与改写 一OpenAI 1 最新大模型GPT 4 Turbo 2 最新发布的高级数据分析 AI画图 图像识别 文档API 3 GPT Store 4 从0到1创建自己的GPT应用 5 模型Ge

随机推荐

  • 医学生可以跨专业考计算机的专业,可以跨考医学研究生:2016跨专业考研需谨慎的专业解读:临床医学...

    每年的跨专业考研人群有很大一批 或是因为本专业就业不景气 或是因为不感兴趣等等 诸多原因导致跨专业考研的人很多 跨专业考研的难度比一般要大 主要因为起点不同 往往此类考生专业课的基础都很低 从头开始 压力很大 因此在选专业的时候一定要谨慎
  • python怎么输出图片_Python怎么输出图片且不保存

    Python怎么输出图片且不保存 一 输出本地图片 使用open 函数来打开图片 使用show 函数来显示图片 from PIL import Image img Image open d dog png img show 这种图片显示方式
  • 基于BP神经网络的2014世界杯比分预测

    写在前头 科学的方法 娱乐的心态 研究背景 众所周知 今年的世界杯比赛各种坑爹 看了那么多砖家点评就没人说准过 当然足球比赛中有太多的未知变量 如何选择这些变量就成为了预测比赛比分的关键 本文作者另辟蹊径 选用足彩比分赔率作为影响比赛走势的
  • Java DAO代码重构(连接池方式)

    DAO设计简化思路 首先初始化数据库连接池 使用Alibaba的Druid连接池 需先下载druid 1 x x jar包 public class JDBCUtil private static DataSource ds null 初始
  • SQLServer数据库漏洞

    一 SQLServer数据库提权前提条件 1 以管理员身份运行数据库服务 2 已经获得SQL数据库的sysadmin权限 3 可以连接数据库 二 通过存储过程进行提权 hydra工具介绍 L 指定用户名字典 P 指定密码字典 vV 输出破解
  • 与孩子一起学编程python_与的解释

    子集上 一 与 康熙筆画 4 部外筆画 3 廣韻 集韻 正韻 同與 說文 賜予也 一勺爲与 六書正譌 寡則均 故从一勺 與 古文 廣韻 弋諸切 正韻 弋渚切 集韻 韻會 演女切 音予 說文 黨與也 戰國策 是君以合齊與强楚 註 與 黨與也
  • 《算法导论》笔记(18) 最大流 含部分习题

    流网络 容量值 源结点 汇点 容量限制 流量守恒 反平行 超级源结点 超级汇点 Ford Fulkerson方法 残存网络 增广路径 最小切割定理 f是最大流 残存网络不包含增广路径 f 等于最小切割容量三者等价 基本的Ford Fulke
  • Vijava 学习笔记之(获取用户自定义规范相关信息)

    源代码 package com vmware customzation import com vmware util Session import com vmware vim25 CustomizationSpecInfo import
  • [CVPR2020]Attention-Guided Hierarchical Structure Aggregation for Image Matting

    论文 Attention Guided Hierarchical Structure Aggregation for Image Matting 代码 wukaoliu CVPR2020 HAttMatting 基于注意力引导的层次结构聚集
  • mycat分库分表

    一 拆分原理 数据节点 分片 主机 ip port 数据库组合起来就是一个数据节点 分库 垂直拆分 不同的表分到不同的数据节点 分表 水平拆分 同一张表按照一定的规则拆分到不同的数据节点 二 mycat逻辑图 应用连接mycat mycat
  • 【编程之路】面试必刷TOP101:堆、栈、队列(42-49,Python实现)

    面试必刷TOP101 堆 栈 队列 42 49 Python实现 42 用两个栈实现队列 小试牛刀 step 1 push操作就正常push到第一个栈末尾 step 2 pop操作时 优先将第一个栈的元素弹出 并依次进入第二个栈中 step
  • 梦幻西游两个不同服务器的名字出现在跨服华山,系统会怎么处理,梦幻西游跨服决战华山玩法介绍...

    梦幻西游跨服决战华山新玩法已经出来了 很多的玩家还不知道该如何玩 下面我们一起来看详细的内容介绍 活动时间 没有帮派竞赛的周五 周日 进入活动场地时间 19 00 比赛时间 19 30 22 00 等级限制 等级 55级 分组机制 根据玩家
  • DLL,SDK,API专业技术术语

    SDK software development kit 中文可译为 软件开发工具包 一般都是一些被软件工程师用于为特定的软件包 软件架构 硬件平台 操作系统等建立应用软件的开发工具的集合 通俗点是指由第三方服务商提供的实现软件产品某项功能
  • 腾讯toB“联合舰队”的秘密

    14 天高强度谈判 每天都从早 8 点持续到凌晨 3 点 郭浩哲和他的同事们敲定了一笔融资 投资方是腾讯 投资金额达到了 12 66 亿元人民币 即使在腾讯的投资历史上 这都不是一个小数额 但实际流程仅用时一个多月 多少让郭浩哲对巨头的速度
  • Eclipse 安装C++环境

    安装CDT插件 方法一 选择 help 安装新的软件 然后点击Add 给定名称为 CDT 添加地址 http download eclipse org tools cdt releases juno 点击FInish 等待安装完成 提示重启
  • 第一课:初识Java语言

    第一课 初识Java语言 一 了解Java的历史由来 1 为什么学习Java编程语言 1 首先要了解编程语言的流行趋势 Tiobe PYPL排行榜 2 在这些排行榜上 Java语言的流行程度都名列前茅 在Tiobe排行榜上 甚至常年 排名第
  • 854. 相似度为 K 的字符串

    对于某些非负整数 k 如果交换 s1 中两个字母的位置恰好 k 次 能够使结果字符串等于 s2 则认为字符串 s1 和 s2 的 相似度为 k 给你两个字母异位词 s1 和 s2 返回 s1 和 s2 的相似度 k 的最小值 示例 1 输入
  • flea-jersey使用之文件上传接入

    文件上传 引言 1 客户端依赖 2 服务端依赖 3 文件上传接入讲解 3 1 服务端上传资源定义 3 2 服务端文件上传服务定义 3 3 客户端文件上传配置 3 4 客户端文件上传调用 引言 本文将要介绍 flea jersey 提供的文件
  • c++回调函数

    关于应用 1 创建struct结构体 typedef struct tag PixelCallBack AsynCall PixelCallBack 2 在 h 文件类中定义private 结构体变量 typedef void func c
  • pytorch——实现自编码器和变分自编码器

    文章目录 数据降维 主成分分析 PCA 自编码器 AE 变分自编码器 VAE pytorch实现 AE 实现自编码器网络结构 实现AE对MNIST数据集的处理 VAE 实现变分自编码器网络结构 实现VAE对MNIST数据集的处理 本文只设计