python - 如何在 Tensorflow 中加载预训练的 LSTM 模型权重

标签 python tensorflow lstm pre-trained-model

我想在 Tensorflow 中实现一个带有预训练权重的 LSTM 模型。这些权重可能来自 Caffee 或 Torch。
我发现 rnn_cell.py 文件中有 LSTM 单元,例如 rnn_cell.BasicLSTMCellrnn_cell.MultiRNNCell。但是我如何才能为这些 LSTM 单元加载预训练权重。

最佳答案

这是加载预训练 Caffe 模型的解决方案。查看full code here ,在 this thread 的讨论中被引用.

net_caffe = caffe.Net(prototxt, caffemodel, caffe.TEST)
caffe_layers = {}

for i, layer in enumerate(net_caffe.layers):
    layer_name = net_caffe._layer_names[i]
    caffe_layers[layer_name] = layer

def caffe_weights(layer_name):
    layer = caffe_layers[layer_name]
    return layer.blobs[0].data

def caffe_bias(layer_name):
    layer = caffe_layers[layer_name]
    return layer.blobs[1].data

#tensorflow uses [filter_height, filter_width, in_channels, out_channels] 2-3-1-0 
#caffe uses [out_channels, in_channels, filter_height, filter_width] 0-1-2-3
def caffe2tf_filter(name):
    f = caffe_weights(name)
    return f.transpose((2, 3, 1, 0))

class ModelFromCaffe():
    def get_conv_filter(self, name):
        w = caffe2tf_filter(name)
        return tf.constant(w, dtype=tf.float32, name="filter")

    def get_bias(self, name):
        b = caffe_bias(name)
        return tf.constant(b, dtype=tf.float32, name="bias")

    def get_fc_weight(self, name):
        cw = caffe_weights(name)
        if name == "fc6":
            assert cw.shape == (4096, 25088)
            cw = cw.reshape((4096, 512, 7, 7)) 
            cw = cw.transpose((2, 3, 1, 0))
            cw = cw.reshape(25088, 4096)
        else:
            cw = cw.transpose((1, 0))

        return tf.constant(cw, dtype=tf.float32, name="weight")

images = tf.placeholder("float", [None, 224, 224, 3], name="images")
m = ModelFromCaffe()

with tf.Session() as sess:
  sess.run(tf.initialize_all_variables())
  batch = cat.reshape((1, 224, 224, 3))
  out = sess.run([m.prob, m.relu1_1, m.pool5, m.fc6], feed_dict={ images: batch })
...

关于python - 如何在 Tensorflow 中加载预训练的 LSTM 模型权重,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/38033632/

相关文章:

python - 从 psycopg2 获取字典

python - 合并 Pandas 中不相交的列

python - 禁用 TensorFlow 调试信息

python - tensorflow 错误: how to make tensor A the same graph as Tensor B

python - 在 Keras 中,如何为 LSTM 层获取 3D 输入和 3D 输出

python - 如何提高lstm训练的准确性

python - 通过猴子修补 DEFAULT_PROTOCOL 提高 pickle.dumps 的性能?

Python - 使用变量作为列表名称

tensorflow - 在 Tensorflow Serving 中调试批处理(未观察到效果)

python - Keras——使用 LSTM 层时精度较低,但不使用 LSTM 时精度很好