Keras Flatten

发布时间 2023-10-08 09:30:36作者: emanlee

Keras Flatten

===============================================================

作用:

Flatten层用来将输入“压平”,即把多维的输入一维化,常用在从卷积层到全连接层的过渡。Flatten不影响batch的大小。

例子:

    from keras.models import Sequential
    from keras.layers.core import Flatten
    from keras.layers.convolutional import Convolution2D
    from keras.utils.vis_utils import plot_model
     
     
    model = Sequential()
    model.add(Convolution2D(64,3,3,border_mode="same",input_shape=(3,32,32)))
    # now:model.output_shape==(None,64,32,32)
     
    model.add(Flatten())
    # now: model.output_shape==(None,65536)
     
    plot_model(model, to_file='Flatten.png', show_shapes=True)

为了更好的理解Flatten层作用,我把这个神经网络进行可视化如下图:

 

 


链接:https://blog.csdn.net/program_developer/article/details/80853425

===============================================================

keras.layers.Flatten(data_format=None)

data_format:一个字符串,其值为 channels_last(默认值)或者 channels_first。它表明输入的维度的顺序。此参数的目的是当模型从一种数据格式切换到另一种数据格式时保留权重顺序。channels_last 对应着尺寸为 (batch, ..., channels) 的输入,而 channels_first 对应着尺寸为 (batch, channels, ...) 的输入。默认为 image_data_format 的值,你可以在 Keras 的配置文件 ~/.keras/keras.json 中找到它。如果你从未设置过它,那么它将是 channels_last
model = Sequential()
model.add(Conv2D(64, (3, 3),
                 input_shape=(3, 32, 32), padding='same',))
# 现在:model.output_shape == (None, 64, 32, 32)

model.add(Flatten())
# 现在:model.output_shape == (None, 65536)

 

 

===============================================================

如果我们考虑创建的原始模型(具有Flatten层),则可以得到以下模型摘要:

1
2
3
4
5
6
7
8
9
10
11
12
13
Layer (type)                 Output Shape              Param #  
=================================================================
D16 (Dense)                  (None, 3, 16)             48        
_________________________________________________________________
A (Activation)               (None, 3, 16)             0        
_________________________________________________________________
F (Flatten)                  (None, 48)                0        
_________________________________________________________________
D4 (Dense)                   (None, 4)                 196      
=================================================================
Total params: 244
Trainable params: 244
Non-trainable params: 0

对于此摘要,下一张图像有望对每一层的输入和输出大小提供更多的了解。

可以读取的Flatten层的输出形状为(None, 48)。这里是提示。您应该阅读(1, 48)或(2, 48)或...或(16, 48) ...或(32, 48),...

实际上,该位置上的None表示任何批量大小。对于召回的输入,第一维表示批次大小,第二维表示输入要素的数量。

在Keras中Flatten层的作用非常简单:

对张量进行展平操作会将张量整形,使其形状等于张量中包含的元素数量(不包括批尺寸)。

enter image description here

注意:我使用了model.summary()方法来提供输出形状和参数详细信息。

 

===============================================================