python - 如何绕过 TensorFlow 中部分神经网络的某些(但不是全部)功能

标签 python tensorflow keras

在我的 TensorFlow 模型中,我将一些数据输入到一堆 CNN 中,然后再进入几个完全连接的层。我已经使用 Keras 的 Sequential 模型实现了这一点。但是,我现在有一些数据不应进入 CNN,而应直接输入第一个全连接层,因为该数据包含一些属于输入数据一部分的值和标签,但该数据不应按原样进行卷积不是图像数据。

这样的事情可以用 tensorflow.keras 实现吗?还是我应该用 tensorflow.nn 来实现?据我了解 Keras' sequential models就是输入从一端进入,从另一端出来,中间没有特殊的接线。

为了做到这一点,我必须对最后一个 CNN 层的数据以及绕过 CNN 的数据使用 tensorflow.concat,然后再将其输入到第一个全连接层,这对吗?

最佳答案

这是一个简单的示例,其中的操作是将来自不同子网的激活相加:

import keras
import numpy as np
import tensorflow as tf
from keras.layers import Input, Dense, Activation

tf.reset_default_graph()

# this represents your cnn model 
def nn_model(input_x):
    feature_maker = Dense(10, activation='relu')(input_x)
    feature_maker = Dense(20, activation='relu')(feature_maker)
    feature_maker = Dense(1, activation='linear')(feature_maker)
    return feature_maker

# a list of input layers, of course the input shapes can be different
input_layers = [Input(shape=(3, )) for _ in range(2)]
coupled_feature = [nn_model(input_x) for input_x in input_layers]

# assume you take the sum of the outputs 
coupled_feature = keras.layers.Add()(coupled_feature)
prediction = Dense(1, activation='relu')(coupled_feature)

model = keras.models.Model(inputs=input_layers, outputs=prediction)
model.compile(loss='mse', optimizer='adam')

# example training set
x_1 = np.linspace(1, 90, 270).reshape(90, 3)
x_2 = np.linspace(1, 90, 270).reshape(90, 3)
y = np.random.rand(90)

inputs_x = [x_1, x_2]

model.fit(inputs_x, y, batch_size=32, epochs=10)

您实际上可以绘制模型以获得更多直觉

from keras.utils.vis_utils import plot_model

plot_model(model, show_shapes=True)

上述代码的模型如下所示

model

关于python - 如何绕过 TensorFlow 中部分神经网络的某些(但不是全部)功能,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/58318195/

相关文章:

tensorflow - TensorBoard 不显示所有数据点

python - TensorFlow Tensor 在 numpy argmax 与 keras argmax 中的处理方式不同

python - 图像边缘检测 Keras 模型损失没有改善

python - “<' not supported between instances of ' 方法”和 'method' - Python、Django

python - 如何使用 solve_ivp 通过精确点?

python - 从字符串中删除数字

python - 保存文本分类模型后获取真实的类标签

python - kivy python 3.x循环添加小部件.kv

python - 如何将 tf.while_loop() 用于 tensorflow 中的可变长度输入?

c++ - c++ 中的 readNetFromTensorflow 错误