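Official PyTorch documentation: torch.nn (PyTorch 1.11.0 documentation)
The main purpose of a nonlinear activation is to add nonlinearity to the network. Without it, any stack of linear layers collapses into a single linear map. The more nonlinearity a network has, the more complex the patterns it can learn to fit. Common nonlinear activations: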
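To see why the nonlinearity matters, here is a minimal sketch (the layer sizes and random inputs are arbitrary choices for illustration). It shows that two `nn.Linear` layers with nothing in between are exactly equivalent to one linear layer.

```python
import torch
from torch import nn

torch.manual_seed(0)
# Two stacked linear layers with no activation in between (sizes chosen arbitrarily)
f1 = nn.Linear(4, 3)
f2 = nn.Linear(3, 2)
x = torch.randn(5, 4)

# Fold them into one equivalent affine map: W = W2 @ W1, b = W2 @ b1 + b2
W = f2.weight @ f1.weight
b = f2.weight @ f1.bias + f2.bias

stacked = f2(f1(x))
single = x @ W.T + b
print(torch.allclose(stacked, single, atol=1e-5))  # the two results match

# Inserting a ReLU between the layers breaks this equivalence in general
with_relu = f2(torch.relu(f1(x)))
```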
ReLU:
![](https://img-blog.csdnimg.cn/5fd53830b2ce43d8b1939fd2fa900c53.png)
Example from the official docs:
```python
>>> import torch
>>> from torch import nn
>>> m = nn.ReLU()
>>> input = torch.randn(2)
>>> output = m(input)
```
An implementation of CReLU - https://arxiv.org/abs/1603.05201
```python
>>> m = nn.ReLU()
>>> input = torch.randn(2).unsqueeze(0)
>>> output = torch.cat((m(input), m(-input)))
```
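The docs examples use random inputs, so the results are hard to check by eye. This sketch uses fixed values instead. The input tensor is an arbitrary choice. It shows that ReLU zeroes out negatives. It also shows that the CReLU example concatenates `ReLU(x)` and `ReLU(-x)` along dim 0, which gives shape `(2, N)`.

```python
import torch
from torch import nn

m = nn.ReLU()
x = torch.tensor([-1.5, 2.0])
print(m(x))  # negatives become 0, positives pass through

# CReLU as in the docs example: keep both the positive and the negative part
x2 = x.unsqueeze(0)                 # shape (1, 2)
crelu = torch.cat((m(x2), m(-x2)))  # concatenated along dim 0 -> shape (2, 2)
print(crelu.shape)
print(crelu)
```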
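Next, write our own implementation. The `ReLU` constructor takes an `inplace` argument:
![](https://img-blog.csdnimg.cn/672916d701334873af200bd84254a355.png)
With `inplace=True`, the input tensor is overwritten with the result. With `inplace=False` (the default), a new tensor is returned and the input is left unchanged. Here we use `inplace=False`: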
```python
import torch
from torch import nn
from torch.nn import ReLU

input = torch.tensor([[1, -0.5],
                      [-1, 3]])
# Reshape to (batch, channels, height, width) = (1, 1, 2, 2); -1 lets PyTorch infer the batch size
input = torch.reshape(input, (-1, 1, 2, 2))
print(input.shape)


class LL(nn.Module):
    def __init__(self):
        super(LL, self).__init__()
        # inplace=False: return a new tensor instead of overwriting the input
        self.relu1 = ReLU(inplace=False)

    def forward(self, input):
        output = self.relu1(input)
        return output


ll = LL()
output = ll(input)
print(output)
```
Output:
![](https://img-blog.csdnimg.cn/01f30e4638c54a34a5b2022d9cb25fce.png)
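The difference between the two `inplace` settings is easy to check directly. This is a minimal sketch reusing the 2×2 values from the example above. With `inplace=True`, the module writes into the tensor it receives and returns that same tensor. This saves memory, but it destroys the original values, which can be a problem if they are still needed elsewhere, for example by autograd.

```python
import torch
from torch import nn

x = torch.tensor([[1.0, -0.5], [-1.0, 3.0]])

# inplace=False (the default): a new tensor is returned, x is untouched
out = nn.ReLU(inplace=False)(x)
print(x)    # still contains -0.5 and -1.0

# inplace=True: the module overwrites its input and returns it
y = x.clone()
out_inplace = nn.ReLU(inplace=True)(y)
print(y)    # negatives in y have been replaced by 0
print(out_inplace.data_ptr() == y.data_ptr())  # same underlying storage
```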
Sigmoid:
![](https://img-blog.csdnimg.cn/4fe17ffffbf14878a79dcb9ed43f91d8.png)
![](https://img-blog.csdnimg.cn/e40e0284c3624fcfab389b8e29400372.png)
Example from the official docs:
```python
import torch
from torch import nn

m = nn.Sigmoid()
input = torch.randn(2)
output = m(input)
```
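`nn.Sigmoid` applies Sigmoid(x) = 1 / (1 + exp(-x)) element-wise. The sketch below checks the module against that formula on a few fixed inputs, which are arbitrary. It also shows that Sigmoid(0) = 0.5 and that every output lies strictly between 0 and 1.

```python
import torch
from torch import nn

m = nn.Sigmoid()
x = torch.tensor([-2.0, 0.0, 2.0])
out = m(x)

# Compute the same values directly from the formula 1 / (1 + e^(-x))
manual = 1 / (1 + torch.exp(-x))
print(out)
print(torch.allclose(out, manual))
```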
Our own implementation, using the CIFAR10 dataset as input:
```python
import torch
import torchvision
from torch import nn
from torch.nn import ReLU, Sigmoid
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# CIFAR10 test split; ToTensor() converts each image to a float tensor with values in [0, 1]
dataset = torchvision.datasets.CIFAR10("./dataset", train=False,
                                       transform=torchvision.transforms.ToTensor(),
                                       download=True)
dataloader = DataLoader(dataset, batch_size=64)


class LL(nn.Module):
    def __init__(self):
        super(LL, self).__init__()
        # self.relu1 = ReLU(inplace=True)
        self.sigmoid1 = Sigmoid()

    def forward(self, input):
        output = self.sigmoid1(input)
        return output


ll = LL()
# Write TensorBoard logs to the "nn.sigmoid" directory
writer = SummaryWriter("nn.sigmoid")
step = 0
for data in dataloader:
    imgs, targets = data
    # Log each batch before and after the activation
    writer.add_images("input", imgs, global_step=step)
    output = ll(imgs)
    writer.add_images("output", output, global_step=step)
    step = step + 1

writer.close()
```
Output (view it with `tensorboard --logdir=nn.sigmoid`):
![](https://img-blog.csdnimg.cn/84124d702a4f45c39dec3fdb32a109c4.png)
![](https://img-blog.csdnimg.cn/fd7a206e0db84653871c608a6c2bdec2.png)
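The output images appear washed out, probably because `ToTensor()` puts every pixel in [0, 1]. On that interval, Sigmoid only produces values from 0.5 up to about 0.731, so the contrast is compressed. This is a quick check on a few representative pixel values, chosen for illustration:

```python
import math
import torch

# Pixel values as produced by ToTensor(): all within [0, 1]
x = torch.tensor([0.0, 0.5, 1.0])
y = torch.sigmoid(x)
print(y)  # approximately 0.5000, 0.6225, 0.7311

# The whole [0, 1] input range is squeezed into [0.5, ~0.731]
print(y.min().item(), y.max().item())
```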