weights = tf.Variable(tf.truncated_normal([2,3]))
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print('Weights:')
print(sess.run(weights))
print("{0:2f}".format(sess.run(weights)))
第一个打印语句按预期工作。
Weights:
[[ 0.30919516 0.29567152 0.11157229]
[ 0.26642913 -0.21269836 -0.58886886]]
使用 str.format() 的第二个打印语句给出以下错误。
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-20-b0e4f93d01f8> in <module>()
4 print('Weights:')
5 print(sess.run(weights))
----> 6 print("{0:2f}".format(sess.run(weights)))
> TypeError: non-empty format string passed to object.__format__
除了下面的答案之外,我还发现 np.set_printoptions( precision=2) 也有效。
np.set_printoptions(precision=2)
weights = tf.Variable(tf.truncated_normal([2,3]))
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print('Weights:')
print(sess.run(weights))
Weights:
[[ 0.01 -0.42 -1.57]
[-0.44 1.62 0.27]]
最佳答案
您尝试打印整个数组,并且 format
需要可以表示为 float 的单个值。尝试这样:
print(np.around(sess.run(weights), 2)
#[[ 0.31 0.30 0.11]
# [ 0.27 -0.21 -0.59]]
此外,正确的格式是0:.2f
关于python - 如何将 tensorflow 变量打印到小数点后两位?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/48710409/