我有一个预训练的Tensorflow检查点,其中的参数都是float32数据类型。
如何将检查点参数加载为 float16?或者有没有办法修改检查点的数据类型?
以下是我的代码片段,它试图将 float32 检查点加载到 float16 图形中,但出现类型不匹配错误。
import tensorflow as tf
A = tf.get_variable(name='foo', shape=[3, 3], dtype=tf.float32)
dense = tf.layers.dense(inputs=A, units=3)
varis = tf.trainable_variables(scope=None)
print(varis[1]) # <tf.Variable 'dense/kernel:0' shape=(3, 3) dtype=float32_ref>
assign = dict([(vari.name, vari) for vari in varis])
saver = tf.train.Saver(assign)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run(dense))
save_path = saver.save(sess, "tmp.ckpt")
tf.reset_default_graph()
A = tf.get_variable(name='foo', shape=[3, 3], dtype=tf.float16)
dense = tf.layers.dense(inputs=A, units=3)
varis = tf.trainable_variables(scope=None)
print(varis[1]) # <tf.Variable 'dense/kernel:0' shape=(3, 3) dtype=float16_ref>
assign = dict([(vari.name, vari) for vari in varis])
saver = tf.train.Saver(assign)
with tf.Session() as sess:
saver.restore(sess, "tmp.ckpt")
print(sess.run(dense))
pass
# errors:
# tensor_name = dense/bias:0; expected dtype half does not equal original dtype float
# tensor_name = dense/kernel:0; expected dtype half does not equal original dtype float
# tensor_name = foo:0; expected dtype half does not equal original dtype float
最佳答案
深入了解 how savers work ,似乎您可以通过 builder
对象重新定义它们的构造。例如,您可以有一个构建器将值加载为 tf.float32
,然后将它们转换为变量的实际类型:
import tensorflow as tf
from tensorflow.python.training.saver import BaseSaverBuilder
class CastFromFloat32SaverBuilder(BaseSaverBuilder):
# Based on tensorflow.python.training.saver.BulkSaverBuilder.bulk_restore
def bulk_restore(self, filename_tensor, saveables, preferred_shard,
restore_sequentially):
from tensorflow.python.ops import io_ops
restore_specs = []
for saveable in saveables:
for spec in saveable.specs:
restore_specs.append((spec.name, spec.slice_spec, spec.dtype))
names, slices, dtypes = zip(*restore_specs)
restore_dtypes = [tf.float32 for _ in dtypes]
with tf.device("cpu:0"):
restored = io_ops.restore_v2(filename_tensor, names, slices, restore_dtypes)
return [tf.cast(r, dt) for r, dt in zip(restored, dtypes)]
请注意,这假设所有恢复的变量都是 tf.float32
。如有必要,您可以根据您的用例适当调整构建器,例如在构造函数中传递一个或多个源类型等。有了这个,您只需要在第二个保护程序中使用上面的构建器就可以让您的示例工作:
import tensorflow as tf
with tf.Graph().as_default(), tf.Session() as sess:
A = tf.get_variable(name='foo', shape=[3, 3], dtype=tf.float32)
dense = tf.layers.dense(inputs=A, units=3)
varis = tf.trainable_variables(scope=None)
assign = {vari.name: vari for vari in varis}
saver = tf.train.Saver(assign)
sess.run(tf.global_variables_initializer())
print('Value to save:')
print(sess.run(dense))
save_path = saver.save(sess, "ckpt/tmp.ckpt")
with tf.Graph().as_default(), tf.Session() as sess:
A = tf.get_variable(name='foo', shape=[3, 3], dtype=tf.float16)
dense = tf.layers.dense(inputs=A, units=3)
varis = tf.trainable_variables(scope=None)
assign = {vari.name: vari for vari in varis}
saver = tf.train.Saver(assign, builder=CastFromFloat32SaverBuilder())
saver.restore(sess, "ckpt/tmp.ckpt")
print('Restored value:')
print(sess.run(dense))
输出:
Value to save:
[[ 0.50589913 0.33701038 -0.11597633]
[ 0.27372625 0.27724823 0.49825498]
[ 1.0897961 -0.29577428 -0.9173869 ]]
Restored value:
[[ 0.506 0.337 -0.11597]
[ 0.2737 0.2773 0.4983 ]
[ 1.09 -0.296 -0.9175 ]]
关于python - 从检查点恢复时,如何更改参数的数据类型?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56557084/