python - Tensorflow:无法将 tf.case 与输入参数一起使用

标签 python tensorflow piecewise

我需要创建一个变量epsilon_n,它根据当前步骤更改定义(和值)。由于我有两个以上的情况,似乎我无法使用 tf.cond 。我尝试按如下方式使用 tf.case:

import tensorflow as tf

####
EPSILON_DELTA_PHASE1 = 33e-4
EPSILON_DELTA_PHASE2 = 2.5
####
step = tf.placeholder(dtype=tf.float32, shape=None)


def fn1(step):
    return tf.constant([1.])

def fn2(step):
    return tf.constant([1.+step*EPSILON_DELTA_PHASE1])

def fn3(step):
    return tf.constant([1.+step*EPSILON_DELTA_PHASE2])

epsilon_n = tf.case(
        pred_fn_pairs=[
            (tf.less(step, 3e4), lambda step: fn1(step)),
            (tf.less(step, 6e4), lambda step: fn2(step)),
            (tf.less(step, 1e5), lambda step: fn3(step))],
            default=lambda: tf.constant([1e5]),
        exclusive=False)

但是,我不断收到此错误消息:

TypeError: <lambda>() missing 1 required positional argument: 'step'

我尝试了以下方法:

epsilon_n = tf.case(
        pred_fn_pairs=[
            (tf.less(step, 3e4), fn1),
            (tf.less(step, 6e4), fn2),
            (tf.less(step, 1e5), fn3)],
            default=lambda: tf.constant([1e5]),
        exclusive=False)

我还是会犯同样的错误。 Tensorflow 文档中的示例适用于没有输入参数传递给可调用函数的情况。我在互联网上找不到有关 tf.case 的足够信息!请帮忙吗?

最佳答案

您需要进行一些更改。 为了保持一致性,您可以将所有返回值设置为变量。

# Since step is a scalar, scalar shape [() or [], not None] much be provided 
step = tf.placeholder(dtype=tf.float32, shape=())


def fn1(step):
    return tf.constant([1.])

# Here you need to use Variable not constant, since you are modifying the value using placeholder
def fn2(step):
    return tf.Variable([1.+step*EPSILON_DELTA_PHASE1])

def fn3(step):
    return tf.Variable([1.+step*EPSILON_DELTA_PHASE2])

epsilon_n = tf.case(
    pred_fn_pairs=[
        (tf.less(step, 3e4), lambda : fn1(step)),
        (tf.less(step, 6e4), lambda : fn2(step)),
        (tf.less(step, 1e5), lambda : fn3(step))],
        default=lambda: tf.constant([1e5]),
    exclusive=False)

关于python - Tensorflow:无法将 tf.case 与输入参数一起使用,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/45886201/

相关文章:

python - 将 __getitem__ 添加到模块

python - 如何组合/集成存储在 3 个数据帧中的 3 个机器学习模型的结果并输出 1 个数据帧,其结果得到多数人同意?

python - BERT 和 ALBERT 的训练数据损失大且准确率低

python-3.x - 使用 Keras 和 TensorFlow 后端可重现结果

python - 分段线性函数与 numpy.piecewise

python - 将两列文本文件中的 float 读取到 Python 中的数组中时出错

python - 在 Python 中计算矢量场的旋度并使用 matplotlib 绘制它

tensorflow - 在 TensorFlow 中重新训练卡住的 *.pb 模型

python - python中的分段列表理解

python - numpy.piecewise 中的多个部分