基于pytorch实现的Auto-encoder模型

2023-11-13

最近因为在自己论文当中可能要用到Auto-encoder 这个东西,学了点皮毛之后想着先按照别人的解释实现一下,然后在MNIST数据集上跑了下测试看看效果。
话不多说直接贴代码。

"""
Author:Media
2020-10-23
"""
import torch
import torch.nn as nn
import torch.utils.data as Data
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd


class MyDataset(torch.utils.data.Dataset):
    def __init__(self, data_root):
        self.data = data_root
        # self.label = data_label

    def __getitem__(self, index):
        data = self.data[index]
        # labels = self.label[index]
        return data  # , labels

    def __len__(self):
        return len(self.data)


# 超参数
# DATA_DIM = 10
EPOCH = 10
BATCH_SIZE = 64
LR = 0.005
BIAS = 0.05
EPOCHS = 10
SAMPLE_SIZE = 10
FILEPATH = ""

def read_csv_file_data(file_path):  # read .csv file
    data = pd.read_csv(file_path)
    train_data = np.array(data, dtype=np.float32)  # np.ndarray()
    train_x_list = torch.from_numpy(train_data)  # list
    return train_x_list


def read_txt_file_data(filepath):  # read .txt file
    data = list()
    for line in open(filepath, 'r'):
        temp = torch.zeros(784)
        tt = line.split(' ')[:-1]
        for item in tt:
            content = item.split(':')
            temp[int(content[0])] = float(content[1])
        data.append(temp)
    return data[10:len(data)-10]


DATA_DIM = 784
HIDE_DIM = 64
traindata = read_txt_file_data(FILEPATH)
train_data = MyDataset(traindata)
trainLoader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True, drop_last=False)


class Auto_Encoder(nn.Module):
    def __init__(self, _input_dim, _hide_dim):
        super(Auto_Encoder, self).__init__()
        self.input_dim = _input_dim
        self.hide_dim = _hide_dim
        self.encoder = Encoder(_input_dim=self.input_dim, _hide_dim=self.hide_dim)
        self.decoder = Decoder(_input_dim=self.input_dim, _hide_dim=self.hide_dim)

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return encoded, decoded

    def output(self, x):
        return self.encoder(x)


class Encoder(nn.Module):
    def __init__(self, _input_dim, _hide_dim):
        super(Encoder, self).__init__()
        self.input_dim = _input_dim
        self.hide_dim = _hide_dim
        self.linear1 = nn.Linear(_input_dim, 512)
        self.linear2 = nn.Linear(512, 256)
        self.linear3 = nn.Linear(256, 128)
        self.linear4 = nn.Linear(128, self.hide_dim)

    def forward(self, x):
        x = torch.tanh(self.linear1(x))
        x = torch.tanh(self.linear2(x))
        x = torch.tanh(self.linear3(x))
        x = self.linear4(x)
        return x


class Decoder(nn.Module):
    def __init__(self, _input_dim, _hide_dim):
        super(Decoder, self).__init__()
        self.input_dim = _input_dim
        self.hide_dim = _hide_dim
        self.linear1 = nn.Linear(_hide_dim, 128)
        self.linear2 = nn.Linear(128, 256)
        self.linear3 = nn.Linear(256, 512)
        self.linear4 = nn.Linear(512, self.input_dim)

    def forward(self, x):
        x = torch.tanh(self.linear1(x))
        x = torch.tanh(self.linear2(x))
        x = torch.tanh(self.linear3(x))
        x = torch.sigmoid(self.linear4(x))
        return x


def draw_mnist(data, title="raw data"):
    data = np.array(data)
    img = data.reshape(28, 28)
    plt.title(title)
    plt.imshow(img, cmap='gray')
    plt.show()


autoencoder = Auto_Encoder(_input_dim=DATA_DIM, _hide_dim=HIDE_DIM)
optimizer = torch.optim.Adam(autoencoder.parameters(), lr=LR)
loss_func = nn.MSELoss()


def learn_by_epoch(epochs):
    epoch = 0
    while epoch < epochs:
        for _, x in enumerate(trainLoader):
            x = torch.tensor(x)
            # y = x
            encoded, decoded = autoencoder(x)
            loss = loss_func(decoded, x)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        if epoch % 100 == 0:
            print('epoch:' + str(epoch) + ' = ' + str(loss.data.item()))
        epoch += 1


def learn_by_bias(bias):
    epochs = 0
    count = 0
    while count < 5:
        for _, x in enumerate(trainLoader):
            x = torch.tensor(x)
            y = x
            encoded, decoded = autoencoder(x)
            loss = loss_func(decoded, y)
            if loss < bias:
                count += 1
            else:
                count = 0
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # if epochs % 100 == 0:
        print('epoch:' + str(epochs) + ' = ' + str(loss.data.item()))
        epochs += 1
    print("train time:= "+str(epochs))


learn_by_epoch(epochs=EPOCHS)
# learn_by_bias(bias=BIAS)
result = []
indices = np.random.choice(len(traindata), SAMPLE_SIZE)
for item in indices:
    # print("input:= "+str(item))
    item = traindata[item].unsqueeze(0)
    _, tt,  = autoencoder(item)
    tt = tt.detach()
    tt = torch.squeeze(tt)
    result.append(tt.numpy())

index = 0
for item in indices:
    draw_mnist(traindata[item])
    draw_mnist(result[index], "auto encoder out")
    index += 1
print(index)

代码中使用的数据集是稀疏存储版的MNIST数据。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

基于pytorch实现的Auto-encoder模型 的相关文章

随机推荐

  • [MATLAB]Jacobi迭代

    MATLAB代码 关于使用雅可比迭代法求线性方程组的数值解 jacobi m 定义Jacobi迭代函数 function x n jacobi A b x0 eps 计算迭代矩阵 D diag diag A L tril A 1 U tri
  • Docker入门到实践 (六) docker网络模式详解以及容器间的网络通信

    文章目录 一 前言 二 docker网络模式介绍 1 默认网络 1 1 bridge网络模式 1 2 host网络模式 1 3 none网络模式 1 4 container网络模式 2 自定义网络 2 1 创建网络 2 2 连接网络 2 3
  • 微软收购暴雪的野心:与索尼争雄 重金布局元宇宙

    1月18日 微软发布声明称 将以全现金方式斥资687亿美元收购游戏巨头动视暴雪 这将成为微软有史以来规模最大的一笔收购 同时也将改写游戏行业的收购纪录 完成这笔收购之后 使命召唤 魔兽世界 糖果传奇 暗黑破坏神 守望先锋 等脍炙人口的作品将
  • element-ui el-cascader 级联选择器 联动默认值

    在使用 element ui 的 el cascader 组件根据后台返回的数据 需要展示一个默认值 官网给出的例子https element eleme cn 2 0 zh CN component cascader 借鉴了一下 话不多说
  • hexo博客搭建-背景知识(二)

    yum与rpm的区别 rpm适用于所有环境 而yum要搭建本地yum源才可以使用 yum是上层管理工具 自动解决依赖性 而rpm是底层管理工具 gcc cc c g 命令行详解 gcc包含的c c 编译器 gcc cc c g gcc和cc
  • JDK8 网络Net包研究(一)

    网络基础 1 国际标准化组织的OSI 开放式系统互联模型 七层模型 2 TCP IP协议 组 四层模型 3 TCP IP协议组 一组包括TCP协议和IP协议 UDP协议 ICMP协议和其他一些协议的协议组 网络层 IP协议 gt 网络互连协
  • sqlserver存储过程基本语法

    转载自 sqlserver存储过程的基本语法 1 定义变量 简单赋值 declare a int set a 5 print a 使用select语句赋值 declare user1 nvarchar 50 select user1 张三
  • ElasticSearch——全文检索

    ElasticSearch 全文检索 来源 尚硅谷 谷粒商城高级篇 一 简介 官网 https www elastic co cn what is elasticsearch 全文搜索属于最常见的需求 开源的 Elasticsearch 是
  • TypeScript学习(一):快速入门

    文章目录 一 TypeScript 简介 1 TypeScript 是什么 2 TypeScript 与 JavaScript 的区别 3 JavaScript 的缺点 4 为什么使用 TypeScript 二 TypeScript 开发环
  • 软件设计命名规范

    1 命名约定 Pascal和Camel命名约定 编程的命名方式主要有Pascal和Camel两种 Pascal 每个单词的首字母大写 例如ProductType Camel 首个单词的首字母小写 其余单词的首字母大写 例如productTy
  • IDA使用之旅(一)用IDA查看最简单的sys文件

    转载请标明是引用于 http blog csdn net chenyujing1234 欢迎大家拍砖 本系列内容是我根据 知其所以然论坛 博主录制的学习视频 做的笔记 使用的IDA软件版本 IDA pro 5 5 参考下载地址 http w
  • 使用Maven插件整合protocol buffer

    本来自己在网上找如何使protocol buffer在IDE 我用的是IDEA 上使用的 结果搜索出来的都不尽人意 因为都太粗略了 没有重点的去阐述 所以最后还是决定自己搜索相关的Maven插件 再慢慢地摸索 费了我好多的时间啊 本人小白
  • gojs 流程图框架-节点装饰器模板(二)

    上一章我们了解了如何使用 gojs 完成基本的节点和连接线的绘制 gojs 中还可以对节点或边进行自由拖动 编辑等功能 本章将基于上一章编写的流程图代码 为这些节点设置装饰器模板 完成后的效果图 建议下载源码 对照本文进行学习 源码地址 g
  • 【11月比赛合集】13场可报名的创新应用、数据分析和程序设计大奖赛,任君挑选!

    CompHub 实时聚合多平台的数据类 Kaggle 天池 和OJ类 Leetcode 牛客 比赛 本账号同时会推送最新的比赛消息 欢迎关注 更多比赛信息见 CompHub主页 或 点击文末阅读原文 以下信息仅供参考 以比赛官网为准 目录
  • 性能优化:虚拟列表,如何渲染10万条数据的dom,页面同时不卡顿

    最近做的一个需求 当列表大概有2万条数据 又不让做成分页 如果页面直接渲染2万条数据 在一些低配电脑上可能会照成页面卡死 基于这个需求 我们来手写一个虚拟列表 思路 列表中固定只显示少量的数据 比如60条 在列表滚动的时候不断的去插入删除d
  • GMP初探

    G Goroutine 协程 用户级的轻量级线程 M 对内核线程的封装 P 为G和M的调度对象 主要用途是用来执行goroutine 维护了一个goroutine队列 即runqueue 由来 单进程时代 这个时代不需要调度器 早起的操作系
  • PMS-adb install安装应用流程(Android L)

    第一次画流程图画的不好 通过adb install安装应用时对framework来说会首先调用Pm java的runInstall 方法 private int runInstall int installFlags 0 int userI
  • mesa调试技巧

    技术关键字 mesa log系统 环境变量 目录 前言 一 gdb或vscode的断点调试 二 mesa log 系统的使用 总结 前言 软件调试技术是要求软件开发人员必备的一项技能 不同的问题具有不同的调试手段和方法 本文从mesa库的实
  • xcode报错:Cycle inside *******

    xcode报错 Cycle inside building could produce unreliable results This usually can be resolved by moving the target s Heade
  • 基于pytorch实现的Auto-encoder模型

    最近因为在自己论文当中可能要用到Auto encoder 这个东西 学了点皮毛之后想着先按照别人的解释实现一下 然后在MNIST数据集上跑了下测试看看效果 话不多说直接贴代码 Author Media 2020 10 23 import t