python - 将 Tensorflow 1.5 转换为 Tensorflow 2

标签 python tensorflow

如何将此 Tensorflow 1.5 代码转换为 Tensorflow 2?

import tensorflow as tf
try:
    Session = tf.Session
except AttributeError:
    Session = tf.compat.v1.Session
A = random_normal([10000,10000])
B = random_normal([10000,10000])
with Session() as sess:
    print(sess.run(tf.reduce_sum(tf.matmul(A,B))))

主要问题是 Tensorflow 2 中删除了 Session 类,并且 compat.v1 层中公开的版本实际上并不兼容。当我使用 Tensorflow 2 运行此代码时,它现在抛出异常:

RuntimeError: Attempting to capture an EagerTensor without building a function.

如果我完全放弃使用 Session,那么功能上仍然等效吗?如果我运行:

import tensorflow as tf
A = random_normal([10000,10000])
B = random_normal([10000,10000])
with Session() as sess:
    print(tf.reduce_sum(tf.matmul(A,B)))

在支持 AVX2 的 Tensoflow 1.16 中,它的运行速度明显更快(0.005 秒 vs 30 秒),而从 pip 安装的普通 Tensorflow 2(不支持 AVX2)也运行得更快一些(30 秒 vs 60 秒)。

为什么使用 Session 会使 Tensorflow 1.16 减慢 6000 倍?

最佳答案

您当然应该利用 TF 2.x 的优势,包括 Eager Execution。不仅非常方便,而且效率更高。

import tensorflow as tf

def get_values():
  A = tf.random.normal([10_000,10_000])
  B = tf.random.normal([10_000,10_000])
  return A,B

@tf.function
def compute():
  A,B = get_values()
  return tf.reduce_sum(tf.matmul(A,B))

print(compute())

您(大多数情况下)在 TF 2.x 中不再需要任何 session ,Auto Graph 会自动为您完成此操作。

只需使用 @tf.function 注释“main”函数(无需注释诸如 get_values 之类的其他函数,这也会自动发生)。

关于python - 将 Tensorflow 1.5 转换为 Tensorflow 2,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/58459444/

相关文章:

python - 如何释放 mongodb 连接?

python - 使用 python mraa 读取英特尔爱迪生上的温度传感器值

python - 分类 : What happens if one class has 4 times as much data as the other class?

machine-learning - 在 tensorflow 中使用 get_variable 进行偏差的零初始化

python - 通过添加 timedelta 创建具有唯一值的 DatetimeIndex 的 DataFrame

python - 使用一个分隔符但多个条件分割字符串

python - Pandas 一次迭代多行重叠

tensorflow - 如何将 keras 层连接到最后一个轴之外

c++ - 在多 session 场景中,如何从一个 tf::Session() 中隐藏 GPU?

具有混合数据类型的 TensorFlow 数据集生成器