python - Tensorflow 可以进行条件计算吗?

标签 python tensorflow deep-learning mxnet

我正在使用 MxNet 开发条件计算框架。假设我们的小批量中有 N 个样本。我需要使用伪代码在我的计算图中执行此类操作:

x = graph.Variable("x")
y = graph.DoSomeTranformations(x)
# The following operation generates a Nxk sized matrix, k responses for each sample.
z = graph.DoDecision(y)
for i in range(k):
   argmax_sample_indices_for_i = graph.ArgMaxIndices(z, i)
   y_selected_samples = graph.TakeSelectedSample(y, argmax_sample_indices_for_i )
   result = graph.DoSomeTransformations(y_selected_samples)

我想要实现的目标如下:获得 y 后,我应用决策函数(这可以是 D 到 k 全连接层,其中 D 是数据维度)并为我的每个样本获得 k 个激活N 大小的小批量。然后,我想根据每个样本的最大激活的列索引,动态地将我的小批量分成 k 个不同的部分(k 可以是 2、3、一个小数字)。我假设的“graph.ArgMaxIndices”函数执行此操作,给定 z、Nxk 大小的矩阵和 i,该函数查找沿 i 列给出最大激活的样本索引并返回它们的索引。 (请注意,我寻找与“graph.ArgMaxIndices”给出等效结果的任何系列或函数组合,而不是单个函数)。最后,对于每个 i,我选择具有最大激活值的样本并对它们应用特定的转换。目前,据我所知,MxNet 在其符号网络中不支持此类条件计算。因此,我在每次决策后构建单独的符号图,并且必须为每个小批量分割编写单独的簿记 - 条件图结构,这会产生 1) 维护和开发非常复杂且繁琐的代码 2) 训练和评估期间运行性能下降。

我的问题是,我可以使用 Tensorflow 的符号运算符来完成上述操作吗?它是否允许人们根据一项标准选择小批量的子集?是否有一个函数或一系列函数相当于上面伪代码中的“graph.ArgMaxIndices”? (给定 Nxk 矩阵和列索引 i,返回在第 k 列具有最大激活值的行索引)。

最佳答案

您可以在 Tensorflow 中做到这一点。

我认为最好的方法是使用面具和 tf.boolean_mask k 次,第 i 个掩码由 tf.equal(i, tf.argmax(z, axis=-1)) 给出

x = graph.Variable("x")
y = graph.DoSomeTranformations(x)
# The following operation generates a Nxk sized matrix, k responses for each sample.
z = graph.DoDecision(y)
max_indices = tf.argmax(z, axis=-1)
for i in range(k):
   argmax_sample_indices_for_i = tf.equal(i, max_indices)
   y_selected_samples = tf.boolean_mask(y, mask=argmax_sample_indices_for_i )
   result = graph.DoSomeTransformations(y_selected_samples)

关于python - Tensorflow 可以进行条件计算吗?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/44567799/

相关文章:

python - 在 Tkinter 中如何将被调用函数作为参数传递?

python - 过滤具有 2 级列标题的 Pandas 数据框

python - 如何在 python 中创建一个在 while 循环中每次累加的函数

tensorflow - tensorflow 中.pb和.pbtxt之间的区别?

python - REDIS:python 中的 redis 不返回任何内容

python - 如何解决tensorflow 1.13.1和python 3.7中的 'import pycocotools._mask is not a valid win32 application'?

tensorflow - 尝试将 'n' 转换为张量但失败。错误: None values not supported

tensorflow - 何时使用 @tf.function 装饰器,何时不使用?我知道 tf.function 构建图形。但是如何知道何时构建图呢?

python - 部署caffe回归模型

machine-learning - Caffe的输入数据标准化