Non-linear Activations

Published: 2023-08-17 10:35:46 · Author: ydky


Common non-linear activation functions include Sigmoid, tanh, ReLU, and Leaky ReLU; see the reference on non-linear activation functions, or the official docs under Non-linear Activations (weighted sum, nonlinearity).

Non-linear function: ReLU

torch.nn.ReLU(inplace=False)

$$
\mathrm{ReLU}(x) = \max(0, x)
$$

# inplace=True: the result is written back into the input
input = -1
ReLU(input, inplace=True)
# input is now 0

# inplace=False (the default): the input is left unchanged
input = -1
output = ReLU(input, inplace=False)
# input is still -1, output is 0
inplace controls whether the result is written back into the input tensor. It defaults to False, and the default is usually kept so that the original input is preserved.
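A minimal runnable check of this behavior, using torch.nn.functional.relu (the functional form of the same operation, which takes the same inplace flag):

import torch
import torch.nn.functional as F

x = torch.tensor([-1.0])
F.relu(x, inplace=True)       # result is written back into x
print(x)                      # tensor([0.])

x = torch.tensor([-1.0])
y = F.relu(x, inplace=False)  # default: x is left unchanged
print(x, y)                   # tensor([-1.]) tensor([0.])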

import torch
from torch import nn
from torch.nn import ReLU

# a 2x2 tensor containing both positive and negative values
input = torch.tensor([[1, -0.5],
                      [-1, 3]])

print(input)
print(input.shape)

class Baserelu(nn.Module):
    def __init__(self):
        super(Baserelu, self).__init__()
        self.relu1 = ReLU()

    def forward(self, input):
        output = self.relu1(input)
        return output

baserelu = Baserelu()
output = baserelu(input)  # negative entries are clamped to 0
print(output)
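Running this should print something like the following (exact formatting may vary by PyTorch version):

tensor([[ 1.0000, -0.5000],
        [-1.0000,  3.0000]])
torch.Size([2, 2])
tensor([[1., 0.],
        [0., 3.]])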

Non-linear function: Sigmoid

torch.nn.Sigmoid(*args, **kwargs)
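As with ReLU above, the element-wise definition (as given in the official docs) is

$$
\mathrm{Sigmoid}(x) = \sigma(x) = \frac{1}{1 + e^{-x}}
$$

which squashes each element into the interval (0, 1).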

# use CIFAR10 as the dataset
import torchvision
from torch import nn
from torch.nn import Sigmoid
from torch.utils.data import DataLoader

from torch.utils.tensorboard import SummaryWriter

dataset = torchvision.datasets.CIFAR10("./dataset2", train=False, transform=torchvision.transforms.ToTensor(), download=True)

dataloader = DataLoader(dataset, batch_size=64)

class Basesigmoid(nn.Module):
    def __init__(self):
        super(Basesigmoid, self).__init__()
        self.sigmoid1 = Sigmoid()

    def forward(self, input):
        output = self.sigmoid1(input)
        return output

basesigmoid = Basesigmoid()

writer = SummaryWriter("logs")

step = 0
for data in dataloader:
    imgs, targets = data
    # note: use add_images (plural) because imgs is a batch of images
    writer.add_images("sigmoid_input", imgs, step)
    output = basesigmoid(imgs)
    writer.add_images("sigmoid_output", output, step)
    step = step + 1

writer.close()
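The logged image grids can then be inspected with tensorboard --logdir=logs. Since ToTensor scales the inputs to [0, 1], Sigmoid maps them into roughly [0.5, 0.73], so the output images appear as washed-out, low-contrast versions of the inputs.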