python - 如何使用 tf.where() 根据条件替换特定值

标签 python tensorflow

我想根据条件替换值。
NumPy 版本会像这样

intensity=np.where(
  np.abs(intensity)<1e-4,
  1e-4,
  intensity)

但是 TensorFlow 对 tf.where() 的用法有点不同
当我尝试这个时

intensity=tf.where(
  tf.math.abs(intensity)<1e-4,
  1e-4,
  intensity)

我收到这个错误

ValueError: Shapes must be equal rank, but are 0 and 4 for 'Select' (op: 'Select') with input shapes: [?,512,512,1], [], [?,512,512,1].

这是否意味着我应该为 1e-4 使用 4 维张量?

最佳答案

以下代码传递了错误

# Create an array which has small value (1e-4),  
# whose shape is (2,512,512,1)
small_val=np.full((2,512,512,1),1e-4).astype("float32")

# Convert numpy array to tf.constant
small_val=tf.constant(small_val)

# Use tf.where()
intensity=tf.where(
  tf.math.abs(intensity)<1e-4,
  small_val,
  intensity)

# Error doesn't occur
print(intensity.shape)
# (2, 512, 512, 1)

关于python - 如何使用 tf.where() 根据条件替换特定值,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/55393557/

相关文章:

python - 计算线性回归问题中的权重

tensorflow 每次运行发现多个图形事件

python - 为什么 TensorFlow 返回 [[nan nan]] 而不是 CSV 文件中的概率?

python - 在Python中动态实例化类

python - 将 Pyqt GUI 主应用程序作为单独的非阻塞进程运行

python - python中的滚动窗口

tensorflow - 在TensorFlow中使用多个图形

machine-learning - 在 keras 中保存模型时出现 Nonetype 错误

python - Azure 应用程序服务 python 模块导入构建时出现错误

python - 在 Python 中实现线性回归