python - How do I change the data type of parameters when restoring from a checkpoint?

Tags: python tensorflow machine-learning

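I have a pre-trained TensorFlow checkpoint whose parameters are all of dtype float32.

How can I load the checkpoint parameters as float16? Or is there a way to modify the data type of the checkpoint itself?

Below is my code snippet. It tries to load a float32 checkpoint into a float16 graph and fails with a type mismatch error.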

import tensorflow as tf

A = tf.get_variable(name='foo', shape=[3, 3], dtype=tf.float32)
dense = tf.layers.dense(inputs=A, units=3)
varis = tf.trainable_variables(scope=None)
print(varis[1])  # <tf.Variable 'dense/kernel:0' shape=(3, 3) dtype=float32_ref>
assign = dict([(vari.name, vari) for vari in varis])
saver = tf.train.Saver(assign)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(dense))
    save_path = saver.save(sess, "tmp.ckpt")

tf.reset_default_graph()
A = tf.get_variable(name='foo', shape=[3, 3], dtype=tf.float16)
dense = tf.layers.dense(inputs=A, units=3)
varis = tf.trainable_variables(scope=None)
print(varis[1])  # <tf.Variable 'dense/kernel:0' shape=(3, 3) dtype=float16_ref>
assign = dict([(vari.name, vari) for vari in varis])
saver = tf.train.Saver(assign)

with tf.Session() as sess:
    saver.restore(sess, "tmp.ckpt")
    print(sess.run(dense))
    pass

# errors:
# tensor_name = dense/bias:0; expected dtype half does not equal original dtype float
# tensor_name = dense/kernel:0; expected dtype half does not equal original dtype float
# tensor_name = foo:0; expected dtype half does not equal original dtype float

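Best answer

Looking a bit into how savers work, it seems you can redefine how they are constructed through a builder object. For example, you could have a builder that loads the values as tf.float32 and then casts them to the actual dtype of each variable: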

import tensorflow as tf
from tensorflow.python.training.saver import BaseSaverBuilder

class CastFromFloat32SaverBuilder(BaseSaverBuilder):
  # Based on tensorflow.python.training.saver.BulkSaverBuilder.bulk_restore
  def bulk_restore(self, filename_tensor, saveables, preferred_shard,
                   restore_sequentially):
    from tensorflow.python.ops import io_ops
    restore_specs = []
    for saveable in saveables:
      for spec in saveable.specs:
        restore_specs.append((spec.name, spec.slice_spec, spec.dtype))
    names, slices, dtypes = zip(*restore_specs)
    restore_dtypes = [tf.float32 for _ in dtypes]
    with tf.device("cpu:0"):
      restored = io_ops.restore_v2(filename_tensor, names, slices, restore_dtypes)
      return [tf.cast(r, dt) for r, dt in zip(restored, dtypes)]

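Note that this assumes every restored variable was saved as tf.float32. If necessary, you can adapt the builder to your use case, for example by passing one or more source dtypes to its constructor. With this in place, you only need to use the builder above in the second saver to make your example work: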

import tensorflow as tf

with tf.Graph().as_default(), tf.Session() as sess:
    A = tf.get_variable(name='foo', shape=[3, 3], dtype=tf.float32)
    dense = tf.layers.dense(inputs=A, units=3)
    varis = tf.trainable_variables(scope=None)
    assign = {vari.name: vari for vari in varis}
    saver = tf.train.Saver(assign)
    sess.run(tf.global_variables_initializer())
    print('Value to save:')
    print(sess.run(dense))
    save_path = saver.save(sess, "ckpt/tmp.ckpt")

with tf.Graph().as_default(), tf.Session() as sess:
    A = tf.get_variable(name='foo', shape=[3, 3], dtype=tf.float16)
    dense = tf.layers.dense(inputs=A, units=3)
    varis = tf.trainable_variables(scope=None)
    assign = {vari.name: vari for vari in varis}
    saver = tf.train.Saver(assign, builder=CastFromFloat32SaverBuilder())
    saver.restore(sess, "ckpt/tmp.ckpt")
    print('Restored value:')
    print(sess.run(dense))

Output:

Value to save:
[[ 0.50589913  0.33701038 -0.11597633]
 [ 0.27372625  0.27724823  0.49825498]
 [ 1.0897961  -0.29577428 -0.9173869 ]]
Restored value:
[[ 0.506    0.337   -0.11597]
 [ 0.2737   0.2773   0.4983 ]
 [ 1.09    -0.296   -0.9175 ]]
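
The small differences between the two printouts are about what float16 rounding would produce. The following is a minimal sketch, using only the Python standard library, that rounds a few of the saved values to half precision so the size of the error can be inspected. The helper name `to_float16` is hypothetical and not part of TensorFlow. It also only illustrates the rounding itself: in the example above the restored result is computed from float16 weights, so it need not match a direct cast of the float32 output bit for bit.

```python
import struct


def to_float16(x):
    """Round a Python float to the nearest IEEE 754 half-precision value.

    Hypothetical helper for illustration; struct's 'e' format is binary16.
    """
    return struct.unpack('<e', struct.pack('<e', x))[0]


# First row of the "Value to save" output shown above.
saved = [0.50589913, 0.33701038, -0.11597633]

for value in saved:
    half = to_float16(value)
    # float16 has an 11-bit significand, so the relative rounding error
    # for normal values is at most 2**-11 (about 4.9e-4).
    print(f"{value!r} -> {half!r} (abs error {abs(half - value):.2e})")
```

float16 keeps roughly three significant decimal digits and its largest finite value is 65504, so it is worth checking the range of the checkpoint values before casting them down.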

A similar question about "python - How do I change the data type of parameters when restoring from a checkpoint?" can be found on Stack Overflow: https://stackoverflow.com/questions/56557084/
