最近因为自己的论文中可能要用到 Auto-encoder(自编码器),学了点皮毛之后,想着先按照别人的讲解自己实现一下,然后在 MNIST 数据集上跑了下测试,看看效果。
话不多说,直接贴代码。
"""
Author:Media
2020-10-23
"""
import torch
import torch.nn as nn
import torch.utils.data as Data
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
class MyDataset(torch.utils.data.Dataset):
    """Label-free dataset that serves items from an indexable collection."""

    def __init__(self, data_root):
        """Store ``data_root``, the indexable collection of samples."""
        self.data = data_root

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the sample at ``index``; no label accompanies it."""
        return self.data[index]
# Hyperparameters and run configuration.
# DATA_DIM = 10
EPOCH = 10  # not referenced in the code shown here; EPOCHS is the value passed to learn_by_epoch
BATCH_SIZE = 64  # mini-batch size handed to the DataLoader
LR = 0.005  # learning rate of the Adam optimizer
BIAS = 0.05  # loss threshold for learn_by_bias (its call is currently commented out)
EPOCHS = 10  # number of passes over the training data made by learn_by_epoch
SAMPLE_SIZE = 10  # how many training samples are reconstructed and plotted after training
FILEPATH = ""  # path of the text data file given to read_txt_file_data; must be filled in before running
def read_csv_file_data(file_path):  # read .csv file
    """Load a CSV file and return its contents as a float32 tensor.

    The file is parsed by ``pandas.read_csv`` with default options and the
    resulting table is converted to a ``float32`` NumPy array, which is then
    wrapped as a torch tensor.

    Args:
        file_path: Anything ``pandas.read_csv`` accepts as its first argument.

    Returns:
        A ``torch.Tensor`` of dtype ``float32`` holding the parsed values.
    """
    frame = pd.read_csv(file_path)
    return torch.from_numpy(np.array(frame, dtype=np.float32))
def read_txt_file_data(filepath, dim=784):  # read .txt file
    """Parse a sparse ``index:value`` text file into dense 1-D tensors.

    Every line becomes one zero-initialised tensor of length ``dim`` whose
    entries are filled from the space-separated ``index:value`` fields on
    that line.

    Args:
        filepath: Path of the text file to read.
        dim: Length of each dense vector. Defaults to 784.

    Returns:
        A list of ``torch.Tensor`` objects, one per line, excluding the first
        ten and the last ten lines of the file (the list is empty when the
        file has twenty lines or fewer).

    Raises:
        OSError: If the file cannot be opened.
        ValueError, IndexError: If a field is not a valid ``index:value``
            pair or the index falls outside ``dim``.
    """
    samples = []
    # The context manager closes the file even when a line fails to parse.
    with open(filepath, 'r') as handle:
        for line in handle:
            dense = torch.zeros(dim)
            # The last space-separated field is discarded; presumably it is
            # the newline left after a trailing space - verify against the
            # data file, since a line without that trailing space would lose
            # its final pair.
            for field in line.split(' ')[:-1]:
                parts = field.split(':')
                dense[int(parts[0])] = float(parts[1])
            samples.append(dense)
    # Drop the first and last ten samples.
    return samples[10:len(samples) - 10]
# Input width (784 = 28 * 28, matching the reshape in draw_mnist) and size of the encoded representation.
DATA_DIM = 784
HIDE_DIM = 64
# Load the samples when the module runs and wrap them for shuffled mini-batch iteration.
traindata = read_txt_file_data(FILEPATH)
train_data = MyDataset(traindata)
trainLoader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True, drop_last=False)
class Auto_Encoder(nn.Module):
    """Autoencoder that chains an ``Encoder`` and a ``Decoder``."""

    def __init__(self, _input_dim, _hide_dim):
        """Build the encoder/decoder pair.

        Args:
            _input_dim: Width of the input (and of the reconstruction).
            _hide_dim: Width of the encoded representation.
        """
        super().__init__()
        self.input_dim = _input_dim
        self.hide_dim = _hide_dim
        self.encoder = Encoder(_input_dim=_input_dim, _hide_dim=_hide_dim)
        self.decoder = Decoder(_input_dim=_input_dim, _hide_dim=_hide_dim)

    def forward(self, x):
        """Return ``(encoded, decoded)`` for the input ``x``."""
        code = self.encoder(x)
        return code, self.decoder(code)

    def output(self, x):
        """Return only the encoder's output for ``x``."""
        return self.encoder(x)
class Encoder(nn.Module):
    """Four-layer fully connected encoder: input -> 512 -> 256 -> 128 -> hidden."""

    def __init__(self, _input_dim, _hide_dim):
        """Create the linear layers.

        Args:
            _input_dim: Width of the input vectors.
            _hide_dim: Width of the encoded output.
        """
        super().__init__()
        self.input_dim, self.hide_dim = _input_dim, _hide_dim
        self.linear1 = nn.Linear(self.input_dim, 512)
        self.linear2 = nn.Linear(512, 256)
        self.linear3 = nn.Linear(256, 128)
        self.linear4 = nn.Linear(128, _hide_dim)

    def forward(self, x):
        """Encode ``x``: tanh after the first three layers, none after the last."""
        hidden = x
        for layer in (self.linear1, self.linear2, self.linear3):
            hidden = torch.tanh(layer(hidden))
        return self.linear4(hidden)
class Decoder(nn.Module):
    """Four-layer fully connected decoder: hidden -> 128 -> 256 -> 512 -> input."""

    def __init__(self, _input_dim, _hide_dim):
        """Create the linear layers.

        Args:
            _input_dim: Width of the reconstructed output.
            _hide_dim: Width of the encoded input.
        """
        super().__init__()
        self.input_dim, self.hide_dim = _input_dim, _hide_dim
        self.linear1 = nn.Linear(self.hide_dim, 128)
        self.linear2 = nn.Linear(128, 256)
        self.linear3 = nn.Linear(256, 512)
        self.linear4 = nn.Linear(512, _input_dim)

    def forward(self, x):
        """Decode ``x``: tanh after the first three layers, sigmoid after the last."""
        hidden = x
        for layer in (self.linear1, self.linear2, self.linear3):
            hidden = torch.tanh(layer(hidden))
        return torch.sigmoid(self.linear4(hidden))
def draw_mnist(data, title="raw data"):
    """Show one flattened 28x28 image in grayscale.

    Args:
        data: Array-like with 784 values; it is converted with ``np.array``
            and reshaped to 28 by 28.
        title: Title placed above the image.
    """
    pixels = np.array(data).reshape(28, 28)
    plt.title(title)
    plt.imshow(pixels, cmap='gray')
    plt.show()
# Model, optimizer and loss shared (as module-level names) by learn_by_epoch and learn_by_bias.
autoencoder = Auto_Encoder(_input_dim=DATA_DIM, _hide_dim=HIDE_DIM)
optimizer = torch.optim.Adam(autoencoder.parameters(), lr=LR)
loss_func = nn.MSELoss()  # mean squared error between the reconstruction and the input
def learn_by_epoch(epochs, log_every=100):
    """Train the module-level ``autoencoder`` for a fixed number of epochs.

    Each epoch iterates once over ``trainLoader`` and, for every batch,
    minimises ``loss_func`` between the reconstruction and the batch itself
    using ``optimizer``.

    Args:
        epochs: Number of full passes over ``trainLoader``.
        log_every: The loss of an epoch's last batch is printed whenever the
            zero-based epoch index is a multiple of this value. Must be a
            positive integer; the default of 100 keeps the previous
            behaviour.
    """
    for epoch in range(epochs):
        for batch in trainLoader:
            # The batch is used as-is: the loader collates the tensor samples
            # into a tensor already, so wrapping it in torch.tensor() would
            # only copy it and emit a UserWarning.
            _, decoded = autoencoder(batch)
            loss = loss_func(decoded, batch)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        if epoch % log_every == 0:
            print(f'epoch:{epoch} = {loss.item()}')
def learn_by_bias(bias, patience=5):
    """Train the module-level ``autoencoder`` until the loss stays below ``bias``.

    A streak counter is incremented for every batch whose loss is below
    ``bias`` and reset to zero by any batch that is not. The counter is only
    checked between epochs, so training stops after the first epoch that ends
    with a streak of at least ``patience`` batches (the streak carries over
    from one epoch to the next). There is no upper bound on the number of
    epochs: if the streak is never reached, the loop does not terminate.

    Args:
        bias: Loss threshold a batch must stay under to extend the streak.
        patience: Required streak length. Defaults to 5.
    """
    epochs = 0
    streak = 0
    while streak < patience:
        for batch in trainLoader:
            # The batch is used as-is: the loader collates the tensor samples
            # into a tensor already, so wrapping it in torch.tensor() would
            # only copy it and emit a UserWarning.
            _, decoded = autoencoder(batch)
            loss = loss_func(decoded, batch)
            streak = streak + 1 if loss.item() < bias else 0
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # Report the loss of the epoch's last batch.
        print(f'epoch:{epochs} = {loss.item()}')
        epochs += 1
    print("train time:= " + str(epochs))
# Train for a fixed number of epochs; the threshold-based variant is left disabled.
learn_by_epoch(epochs=EPOCHS)
# learn_by_bias(bias=BIAS)

# Reconstruct SAMPLE_SIZE randomly chosen training samples.
indices = np.random.choice(len(traindata), SAMPLE_SIZE)
result = []
for sample_idx in indices:
    # Add a leading batch dimension of one before feeding the model.
    single = traindata[sample_idx].unsqueeze(0)
    reconstruction = autoencoder(single)[1]
    # Detach from the graph and drop the batch dimension before converting to NumPy.
    result.append(reconstruction.detach().squeeze().numpy())

# Show each original next to its reconstruction.
# NOTE(review): the running count is assumed to be printed after every pair - confirm.
for shown, (sample_idx, reconstruction) in enumerate(zip(indices, result), start=1):
    draw_mnist(traindata[sample_idx])
    draw_mnist(reconstruction, "auto encoder out")
    print(shown)
代码中使用的数据集是稀疏存储版的 MNIST 数据:文件中每一行对应一张图片,由若干以空格分隔的“像素下标:像素值”项组成,未出现的像素默认为 0。