pytorch笔记12--无监督的AutoEncoder(自编码)

2023-10-26

1. AutoEncoder: 给特征属性降维

2. Data---->压缩(提取Data的关键信息,减小网络的运算压力)---->data(具有代表性的特征)---->解压(还原数据信息)---->Pred_Data

3. 使用Mnist数据集训练,将数据先压缩再解压,并用训练集的前5张图片可视化训练的过程,过程图和结果图如下:

可视化训练集前200张图片的预测类别结果:

# 使用MNIST数据集先压缩再解压,用压缩的特征进行分监督分类   (无监督)
import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
import numpy as np

# Hyper parameters
EPOCH=10
BATCH_SIZE=64
LR=0.005
DOWNLOAD_MNIST=False

# mnist dataset
train_data=torchvision.datasets.MNIST(
    root='./mnist/',
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=DOWNLOAD_MNIST
)
train_loader=Data.DataLoader(dataset=train_data,batch_size=BATCH_SIZE,shuffle=True)
# AutoEncoder: encoder & decoder, 压缩后得到压缩的特征值,再从压缩的特征值中解压出原图
class AutoEncoder(nn.Module):
    def __init__(self):
        super(AutoEncoder, self).__init__()
        self.encoder=nn.Sequential(
            nn.Linear(28*28,128),
            nn.Tanh(),
            nn.Linear(128,64),
            nn.Tanh(),
            nn.Linear(64,12),
            nn.Tanh(),
            nn.Linear(12,3)    #压缩为3个特征(3d),进行3D图像的可视化
        )
        self.decoder=nn.Sequential(
            nn.Linear(3,12),
            nn.Tanh(),
            nn.Linear(12,64),
            nn.Tanh(),
            nn.Linear(64,128),
            nn.Tanh(),
            nn.Linear(128,28*28),
            nn.Sigmoid() # 让输出值在 (0,1)
        )
    def forward(self,x):
        encoded=self.encoder(x)
        decoded=self.decoder(encoded)
        return encoded,decoded    # 返回压缩后的结果 和 解压后的结果

autoencoder=AutoEncoder()

# training
optimizer=torch.optim.Adam(autoencoder.parameters(),lr=LR)
loss_func=nn.MSELoss()

f,a=plt.subplots(2,5,figsize=(5,2))
plot_data=train_data.data[:5].view(-1,28*28).type(torch.FloatTensor)/255   # 训练过程中显示的图片
for i in range(5):
    a[0][i].imshow(np.reshape(plot_data.data.numpy()[i],(28,28)),cmap='gray')
    a[0][i].set_xticks(())     # 是刻度不显示
    a[0][i].set_yticks(())

for epoch in range(EPOCH):
    for step,(x,label) in enumerate(train_loader):
        b_x=x.view(-1,28*28)   # reshape x to(batch,28*28)
        b_y=x.view(-1,28*28)

        encoded,decoded=autoencoder(b_x)

        loss=loss_func(decoded,b_y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step%100==0:    # 每100步更新一次解压的图片
            print('Epoch: ',epoch,'| train loss: %.4f'%loss.data.numpy())

            # 显示解压过程中的图片变化(第一行是原图,第二行是训练过程中的图片)
            _,decoded_data=autoencoder(plot_data)
            for i in range(5):
                a[1][i].clear()
                a[1][i].imshow(np.reshape(decoded_data.data.numpy()[i],(28,28)),cmap='gray')
                a[1][i].set_xticks(())
                a[1][i].set_yticks(())
            plt.draw()
            plt.pause(0.05)
plt.show()

# 可视化效果
view_data=train_data.data[:200].view(-1,28*28).type(torch.FloatTensor)/255
encoded_data,_=autoencoder(view_data)   # 提取压缩的特征值
fig=plt.figure(2)
ax=Axes3D(fig)

# X,Y,Z: 图片压缩后的3个特征值
X=encoded_data.data[:,0].numpy()
Y=encoded_data.data[:,1].numpy()
Z=encoded_data.data[:,2].numpy()
labels=train_data.targets[:200].numpy()
for x,y,z,lbl in zip(X,Y,Z,labels):
    c=cm.rainbow(int(255*lbl/9))            #上色  0~9
    ax.text(x,y,z,lbl,fontdict={'color':c})

ax.set_xlim(X.min(),X.max())
ax.set_ylim(Y.min(),Y.max())
ax.set_zlim(Z.min(),Z.max())
plt.show()

 

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

pytorch笔记12--无监督的AutoEncoder(自编码) 的相关文章

随机推荐

  • 【pytorch】pytorch模型保存技巧

    Pytorch会把模型相关信息保存为一个字典结构的数据 以用于继续训练或者推理 1 保存与加载模型参数 这是最常见的模型保存与加载方式 保存方式如下 state model state dict torch save state xxx p
  • qml实现红绿灯切换功能

    题目要求 参考代码 https download csdn net download y478225902 5260541 实现源码 import QtQuick 2 12 import QtQuick Window 2 12 Window
  • springboot整合maven Profile实现properties文件多环境配置

    步骤 首先写几个properties的配置文件 一般这样的文件有三个 而且文件的名称也也可以随意 不论你们的项目是使用的springmvc还是springboot 文件名称都可以随意指定 例如我的几个文件 在文件中写一些测试的属性值 方便测
  • 【一】重温HTML

    引言 经典对答 面试官 你了解HTML吗 回答 啊 我是来面试前端的呀 我会Vue 面试官 写文思考 写这一系列文章的时候 自己思考了几个问题 HTML的文章太多了 为什么还要写 HTML的入门谁不会 还要学 HTML的文章基本都是水文 谁
  • ES6解构赋值

    前面的话 我们经常定义许多对象和数组 然后有组织地从中提取相关的信息片段 在ES6中添加了可以简化这种任务的新特性 解构 解构是一种打破数据结构 将其拆分为更小部分的过程 本文将详细介绍ES6解构赋值 引入 在ES5中 开发者们为了从对象和
  • Mysql中MVCC的使用及原理详解

    准备 测试环境 Mysql 5 7 20 log 数据库默认隔离级别 RR Repeatable Read 可重复读 MVCC主要适用于Mysql的RC RR隔离级别 创建一张存储引擎为testmvcc的表 sql为 CREATE TABL
  • error compiling template但编辑器内未报错,处理步骤。

    1 首先寻找自己所引入的组件当中 例如用到了某个方法 而自己没有把方法写上 2 寻找自己所引入的代码当中是否有重复的代码 可能是复制的时候多复制一行而导致的 3 寻找是否有空格所导致的error compiling template 报错
  • 到处是“坑”的strtok()—解读strtok()的隐含特性

    在用C C 实现字符串处理逻辑时 strtok函数的使用非常广泛 其主要作用是按照给定的字符集分隔字符串 并返回各子字符串 由于该函数的使用有诸多限制 如果使用不当就会造成很多 坑 因此本文首先介绍那些经常误踩的坑 然后通过分析源代码 解读
  • Android——第三方Facebook授权登录获取用户信息

    由于项目中需要使用Facebook进行一键登录 所以记录下步骤 其实小伙伴直接看官网也可以 介绍的蛮详细的 先看下效果图吧 遵循以下步骤将Facebook登录添加到您的应用 Facebook开发者网站 https developers fa
  • bin文件转成C语言数组之c代码

    反汇编的时候用的着 include
  • Js弹出showModalDialog窗口---返回值或数组

    function showMyModalDialog url width height showModalDialog url dialogWidth width px dialogHeight height px center yes s
  • ACwing :01背包问题

    朴素的 动规的 基本表示 f i j 表示只看前 i 个物品 总体积是 j 的情况下 总价值最大是多少 result max f n 0 V f i j 1 不选第 i 个物品 f i j f i 1 j 2 选第 i个物品 f i j f
  • ubuntu 如何使用 root 用户

    环境 virtual box 6 1 ubuntu 1604 LTS 64 问题 一般的ubuntu会创建一个管理员用户 在使用 su 指令从管理员切换到root用户后 设在 etc profile的环境变量丢失 如何才能保证环境变量不变呢
  • Android开发中怎么实现上传图片到服务器

    要实现在Android开发中上传图片到服务器 可以按照以下步骤进行 1 在Android项目中添加相应的权限 确保应用程序可以访问设备上的照片或相机 在 AndroidManifest xml 文件中添加以下权限
  • linux服务端下的c++ udp socket demo

    linux服务端 udp socket demo 如下 创建接受数据的socket int iSock socket PF INET SOCK DGRAM 0 printf socket ss d n iSock struct sockad
  • 三种基于CUDA的归约计算

    归约在并行计算中很常见 并且在实现上具有一定的套路 本文分别基于三种机制 Intrinsic 共享内存 atomic 实现三个版本的归约操作 完成一个warp 32 大小的整数数组的归约求和计算 Intrinsic版本 基于Intrinsi
  • 网站视频服务器架设,云服务器架设网站视频教程

    云服务器架设网站视频教程 内容精选 换一换 安装MySQL本文档以 CentOS 6 5 64bit 40GB 操作系统为例 对应MySQL版本为5 1 73 CentOS 7及以上版本将MySQL数据库软件从默认的程序列表中移除 需执行s
  • Keil常见错误警告

    1 warning 767 D conversion from pointer to smaller integer 解释 将指针转换为较小的整数 影响 可能造成的影响 容易引起数据截断 造成不必要的数据丢失 如果出现 bug 很难 调试
  • mybatis 的mapper接口注入到spring 容器的源码解析

    一 环境准备 1 创建一个maven 项目 其POM文件如下
  • pytorch笔记12--无监督的AutoEncoder(自编码)

    1 AutoEncoder 给特征属性降维 2 Data gt 压缩 提取Data的关键信息 减小网络的运算压力 gt data 具有代表性的特征 gt 解压 还原数据信息 gt Pred Data 3 使用Mnist数据集训练 将数据先压