This is the f(x) I want to implement in TensorFlow.
The input is x = (x1,x2,x3,x4,x5,x6,x7,x8,x9).
Define f(x) = f1(x1,x2,x3,x4,x5) + f2(x5,x6,x7,x8,x9),
where
f1(x1,x2,x3,x4,x5) = {1 if (x1,x2,x3,x4,x5)=(0,0,0,0,0),
                      g1(x1,x2,x3,x4,x5) otherwise}
f2(x5,x6,x7,x8,x9) = {1 if (x5,x6,x7,x8,x9)=(0,0,0,0,0),
                      g2(x5,x6,x7,x8,x9) otherwise}
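To make the definition concrete, here is a minimal single-sample sketch in plain NumPy. The bodies of `g1` and `g2` (a sum and a negated sum) are placeholder assumptions standing in for the two MLPs, and the helper name `piece` is hypothetical.

```python
import numpy as np

def g1(block):
    # Placeholder for the first MLP (assumption): sum of the inputs.
    return float(np.sum(block))

def g2(block):
    # Placeholder for the second MLP (assumption): negated sum of the inputs.
    return float(-np.sum(block))

def piece(block, g, value_at_zero=1.0):
    # Return the fixed value when the block is all zeros, otherwise g(block).
    if np.all(block == 0):
        return value_at_zero
    return g(block)

def f(x):
    x = np.asarray(x, dtype=np.float32)
    # x5 (index 4) is shared by both blocks, as in the definition.
    return piece(x[0:5], g1) + piece(x[4:9], g2)

print(f([0, 0, 0, 0, 0, 0, 0, 0, 0]))
print(f([1, 0, 0, 0, 0, 0, 0, 0, 0]))
```

This version decides per sample with an ordinary `if`, which is exactly what does not carry over directly to a whole batch.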
Here is my TensorFlow code:
import tensorflow as tf
import numpy as np
ph = tf.placeholder(dtype=tf.float32, shape=[None, 9])
x1 = tf.slice(ph, [0, 0], [-1, 5])
x2 = tf.slice(ph, [0, 4], [-1, 5])
fixed1 = tf.placeholder(dtype=tf.float32, shape=[1, 5])
fixed2 = tf.placeholder(dtype=tf.float32, shape=[1, 5])
# MLP 1
w1 = tf.Variable(tf.ones([5, 1]))
g1 = tf.matmul(x1, w1)
# MLP 2
w2 = tf.Variable(-tf.ones([5, 1]))
g2 = tf.matmul(x2, w2)
check1 = tf.reduce_all(tf.equal(x1, fixed1), axis=1, keep_dims=True)
check2 = tf.reduce_all(tf.equal(x2, fixed2), axis=1, keep_dims=True)
#### with Problem
f1 = tf.cond(check1,
             lambda: tf.constant([2], dtype=tf.float32), lambda: g1)
f2 = tf.cond(check2,
             lambda: tf.constant([1], dtype=tf.float32), lambda: g2)
####
f = tf.add(f1, f2)
x = np.array([[0, 0, 0, 0, 0, 0, 0, 0, 0],
              [0, 0, 0, 0, 0, 0, 0, 0, 1],
              [1, 0, 0, 0, 0, 0, 0, 0, 0],
              [2, 0, 0, 0, 0, 0, 0, 0, 0],
              [9, 0, 0, 0, 0, 0, 0, 0, 0]])
fixed = np.array([[0, 0, 0, 0, 0]])
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print('(1)\n', sess.run(check1, feed_dict={ph: x, fixed1: fixed, fixed2: fixed}))
    print('(2)\n', sess.run(check2, feed_dict={ph: x, fixed1: fixed, fixed2: fixed}))
    print('(3)\n', sess.run(f, feed_dict={ph: x, fixed1: fixed, fixed2: fixed}))
    print('(4)\n', sess.run(f1, feed_dict={ph: x, fixed1: fixed, fixed2: fixed}))
    print('(5)\n', sess.run(f2, feed_dict={ph: x, fixed1: fixed, fixed2: fixed}))
In this case,
check1 is [[ True], [ True], [False], [False], [False]] with shape (5, 1), and
check2 is [[ True], [False], [ True], [ True], [ True]] with shape (5, 1).
I expect f to be [[3], [1], [2], [3], [10]],
but it seems tf.cond() cannot take a bool tensor of shape (5, 1) as its predicate.
Could you suggest how to implement f(x) in TensorFlow?
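The check values and the expected result quoted above can be reproduced outside TensorFlow. The following is a NumPy sketch, not the asker's code: it assumes the same weights as the question (all ones for g1, all minus ones for g2) and the same constants as the two `tf.cond` calls (2 and 1), and uses `np.where` for the row-wise selection.

```python
import numpy as np

x = np.array([[0, 0, 0, 0, 0, 0, 0, 0, 0],
              [0, 0, 0, 0, 0, 0, 0, 0, 1],
              [1, 0, 0, 0, 0, 0, 0, 0, 0],
              [2, 0, 0, 0, 0, 0, 0, 0, 0],
              [9, 0, 0, 0, 0, 0, 0, 0, 0]], dtype=np.float32)

# Same overlapping slices as in the question: columns 0..4 and 4..8.
x1, x2 = x[:, 0:5], x[:, 4:9]

# Linear stand-ins for the two MLPs, with the question's initial weights.
g1 = x1 @ np.ones((5, 1), dtype=np.float32)
g2 = x2 @ -np.ones((5, 1), dtype=np.float32)

# One boolean per row, kept as a column so it lines up with g1 and g2.
check1 = np.all(x1 == 0, axis=1, keepdims=True)
check2 = np.all(x2 == 0, axis=1, keepdims=True)

# Row-wise selection: the constant where the check holds, g otherwise.
f1 = np.where(check1, 2.0, g1)
f2 = np.where(check2, 1.0, g2)
f = f1 + f2

print(check1.ravel())
print(check2.ravel())
print(f.ravel())
```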
This is the error message I received:
Traceback (most recent call last):
  File "C:\Users\hong\AppData\Local\Continuum\Anaconda3\lib\site-packages\tensorflow\python\framework\common_shapes.py", line 670, in _call_cpp_shape_fn_impl
    status)
  File "C:\Users\hong\AppData\Local\Continuum\Anaconda3\lib\contextlib.py", line 66, in __exit__
    next(self.gen)
  File "C:\Users\hong\AppData\Local\Continuum\Anaconda3\lib\site-packages\tensorflow\python\framework\errors_impl.py", line 469, in raise_exception_on_not_ok_status
    pywrap_tensorflow.TF_GetCode(status))
tensorflow.python.framework.errors_impl.InvalidArgumentError: Shape must be rank 0 but is rank 2 for 'cond/Switch' (op: 'Switch') with input shapes: [?,1], [?,1].
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
  File "C:/Users/hong/Dropbox/MLILAB/Research/GM-MLP/code/tensorflow_cond.py", line 23, in <module>
    lambda: tf.constant([2], dtype=tf.float32), lambda: g1)
  File "C:\Users\hong\AppData\Local\Continuum\Anaconda3\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 1765, in cond
    p_2, p_1 = switch(pred, pred)
  File "C:\Users\hong\AppData\Local\Continuum\Anaconda3\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 318, in switch
    return gen_control_flow_ops._switch(data, pred, name=name)
  File "C:\Users\hong\AppData\Local\Continuum\Anaconda3\lib\site-packages\tensorflow\python\ops\gen_control_flow_ops.py", line 368, in _switch
    result = _op_def_lib.apply_op("Switch", data=data, pred=pred, name=name)
  File "C:\Users\hong\AppData\Local\Continuum\Anaconda3\lib\site-packages\tensorflow\python\framework\op_def_library.py", line 759, in apply_op
    op_def=op_def)
  File "C:\Users\hong\AppData\Local\Continuum\Anaconda3\lib\site-packages\tensorflow\python\framework\ops.py", line 2242, in create_op
    set_shapes_for_outputs(ret)
  File "C:\Users\hong\AppData\Local\Continuum\Anaconda3\lib\site-packages\tensorflow\python\framework\ops.py", line 1617, in set_shapes_for_outputs
    shapes = shape_func(op)
  File "C:\Users\hong\AppData\Local\Continuum\Anaconda3\lib\site-packages\tensorflow\python\framework\ops.py", line 1568, in call_with_requiring
    return call_cpp_shape_fn(op, require_shape_fn=True)
  File "C:\Users\hong\AppData\Local\Continuum\Anaconda3\lib\site-packages\tensorflow\python\framework\common_shapes.py", line 610, in call_cpp_shape_fn
    debug_python_shape_fn, require_shape_fn)
  File "C:\Users\hong\AppData\Local\Continuum\Anaconda3\lib\site-packages\tensorflow\python\framework\common_shapes.py", line 675, in _call_cpp_shape_fn_impl
    raise ValueError(err.message)
ValueError: Shape must be rank 0 but is rank 2 for 'cond/Switch' (op: 'Switch') with input shapes: [?,1], [?,1].
Process finished with exit code 1
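Best answer
I think you need tf.where, not tf.cond. As the error message says, tf.cond expects a single scalar (rank 0) predicate, whereas tf.where selects between two tensors element by element based on a bool tensor.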
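A minimal sketch of that suggestion, assuming TensorFlow 2.x with eager execution (the question's code targets the 1.x graph API with placeholders and a session). The function name `f` and the choice to pass the weights and fixed vectors as arguments are illustrative, not from the original post. `tf.ones_like` is used so that both branches have the same shape as the condition, which avoids relying on broadcasting inside `tf.where`.

```python
import tensorflow as tf  # assumption: TensorFlow 2.x, eager execution

def f(x, w1, w2, fixed1, fixed2):
    # Overlapping slices, as in the question: columns 0..4 and 4..8.
    x1 = x[:, 0:5]
    x2 = x[:, 4:9]

    # Linear stand-ins for the two MLPs, shape (batch, 1).
    g1 = tf.matmul(x1, w1)
    g2 = tf.matmul(x2, w2)

    # One boolean per row, shape (batch, 1).
    check1 = tf.reduce_all(tf.equal(x1, fixed1), axis=1, keepdims=True)
    check2 = tf.reduce_all(tf.equal(x2, fixed2), axis=1, keepdims=True)

    # Element-wise selection replaces the two tf.cond calls.
    f1 = tf.where(check1, 2.0 * tf.ones_like(g1), g1)
    f2 = tf.where(check2, tf.ones_like(g2), g2)
    return f1 + f2

x = tf.constant([[0, 0, 0, 0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0, 0, 0, 1],
                 [1, 0, 0, 0, 0, 0, 0, 0, 0],
                 [2, 0, 0, 0, 0, 0, 0, 0, 0],
                 [9, 0, 0, 0, 0, 0, 0, 0, 0]], dtype=tf.float32)
fixed = tf.zeros([1, 5])
w1 = tf.Variable(tf.ones([5, 1]))
w2 = tf.Variable(-tf.ones([5, 1]))

result = f(x, w1, w2, fixed, fixed).numpy()
print(result)
```

Unlike tf.cond, tf.where computes both candidate tensors for every row and then picks one per element, so g1 and g2 are evaluated even for the all-zero rows.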
Regarding "machine-learning - How can I feed a bool tensor to tf.cond() instead of just a single bool value?", a similar question can be found on Stack Overflow: https://stackoverflow.com/questions/43203075/