python - 在内存中序列化和反序列化 Tensorflow 模型并继续训练

标签 python serialization tensorflow keras

我已经看到了这个问题的变体,但我还没有找到令人满意的答案。基本上,我想从 keras model.to_json()model.get_weights()model.from_json() 中执行等效操作, model.set_weights() 到 tensorflow 。我想我快接近那里了,但我正处于被困的地步。如果我能在同一个字符串中获得权重和图形,我更愿意,但我理解这是否不可能。

目前,我拥有的是:

g = optimizer.minimize(loss_op,
                       global_step=tf.train.get_global_step())
de = g.graph.as_graph_def()
json_string = json_format.MessageToJson(de)

gd = tf.GraphDef()
gd = json_format.Parse(json_string, gd)

这似乎可以很好地创建图形,但显然元图不包括变量、权重等。还有元图,但我唯一看到的是 export_meta_graph,它似乎没有序列化同样的方式。我看到 MetaGraph 有一个原型(prototype)函数,但我不知道如何序列化那些变量。

简而言之,您将如何采用 tensorflow 模型(权重、图形等模型),将其序列化为字符串(最好是 json),然后反序列化并继续训练或提供预测。

这里有一些让我接近那里的东西,我已经尝试过,但大多数情况下在需要写入磁盘方面有限制,在这种情况下我不能这样做:

Gist on GitHub

This is the closest one I found, but the link to serializing a metagraph doesn't exist.

最佳答案

请注意,@Maxim 的解决方案每次运行时都会在图中创建新的操作。

如果您非常频繁地运行该函数,这将导致您的代码变得越来越慢。

解决此问题的两种解决方案:

  1. 与图的其余部分同时创建分配操作并重用它们:

    assign_ops = [] 对于 tf.trainable_variables() 中的 var_name: assign_placeholder = tf.placeholder(var.dtype, shape=value.shape) assign_op = var.assign(assign_placeholder) assign_ops.append(assign_op)

  2. 对变量使用加载函数,我更喜欢这个,因为它不需要上面的代码:

    self.params = tf.trainable_variables()

    def get_weights(self): values = tf.get_default_session().run(self.params) 返回值

    def set_weights( self ,权重): 对于 i,枚举值(权重): 值 = np.asarray(值) self.params[i].load(value, self.sess)

(我不能发表评论,所以我把它作为答案)

关于python - 在内存中序列化和反序列化 Tensorflow 模型并继续训练,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/47716395/

相关文章:

python - 验证序列化器 Django Rest 框架中的外键字段

python - 为什么 print() 在 Visual Studio Code 中不起作用?

python - 奇怪的 numpy fft 性能

java - 如何在不使用文件的情况下序列化对象(例如 HashMap)?

java - 如何使用 @Configurable 结合 readResolve() 来注入(inject)依赖项

python - 在 C++ 中为 Tensorflow 模型定义一个 feed_dict

installation - Tensorflow安装错误: not a supported wheel on this platform

python - 如何将 Tensorflow 模型转换为 tensorflow.js 模型?

python - docker python flask 得到 "Do you want to continue"然后 "executor failed"

Python:组合字符串和列表