<分区>
标签 ios tensorflow keras
为什么我们总是使用 , seed = 1234 in tf.compat.v1.random.set_random_seed(seed)。有什么具体原因吗?”
最佳答案
是的,让所有操作生成的随机序列在 session 中可重复。顺便说一句,seed=1234
是随机的。您可以选择任何值。
例如:
a = tf.random.uniform([1])
b = tf.random.normal([1])
# Repeatedly running this block with the same graph will generate the
# different sequences of 'a' and 'b' across sessions.
print("Session 1")
with tf.Session() as sess1:
print(sess1.run(a)) # generates 'A1'
print(sess1.run(a)) # generates 'A2'
print(sess1.run(b)) # generates 'B1'
print(sess1.run(b)) # generates 'B2'
print("Session 2")
with tf.Session() as sess2:
print(sess2.run(a)) # generates 'A3'
print(sess2.run(a)) # generates 'A4'
print(sess2.run(b)) # generates 'B3'
print(sess2.run(b)) # generates 'B4'
输出:
Session 1
[0.35214436]
[0.61644566]
[-0.2290629]
[0.8414659]
Session 2
[0.62713313]
[0.5924448]
[-0.5366475]
[-0.36064562]
但是,在使用 tf.random.set_random_seed(1234)
设置图级种子后:
tf.random.set_random_seed(1234)
a = tf.random.uniform([1])
b = tf.random.normal([1])
# Repeatedly running this block with the same graph will generate the same
# sequences of 'a' and 'b'.
print("Session 1")
with tf.Session() as sess1:
print(sess1.run(a)) # generates 'A1'
print(sess1.run(a)) # generates 'A2'
print(sess1.run(b)) # generates 'B1'
print(sess1.run(b)) # generates 'B2'
print("Session 2")
with tf.Session() as sess2:
print(sess2.run(a)) # generates 'A1'
print(sess2.run(a)) # generates 'A2'
print(sess2.run(b)) # generates 'B1'
print(sess2.run(b)) # generates 'B2'
输出:
Session 1
[0.53202796]
[0.91749656]
[-1.3118125]
[-0.44506428]
Session 2
[0.53202796]
[0.91749656]
[-1.3118125]
[-0.44506428]
您还可以设置操作级种子,如 a = tf.random.uniform([1], seed=1)
。从官方文档 here 中阅读更多相关信息。
关于ios - 为什么我们总是在tf.compat.v1.random.set_random_seed(1234)中使用seed =1234。有什么具体原因吗?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57301699/
相关文章:
ios - 从 View Controller 访问 TableView 单元格中的 UISegmentedControl
python - Tensorflow:仅在需要时将图像加载到内存中
neural-network - 为什么在 Keras 中 CNN 的训练速度比完全连接的 MLP 慢?
ios - objective-C : UIImagePickerControllerReferenceURL get ID
tensorflow - 使用 3 元素线性数据来训练模型,并包含异常值测试数据。为什么测试准确率仍然是100%?
python - 使用faster_rcnn_nas_coco模型训练时设置 "second_stage_batch_size*"是什么意思?
keras - 如何在 ImageDataGenerator 中将 featurewise_center=True 与 flow_from_directory 一起使用?