python - 在多种选项中,如何在 TensorFlow 中随机执行其中一种?

标签 python tensorflow

如何以可训练的方式从多个备选方案中随机选择一个执行流程?例如:

import random
from tensorflow import keras

class RandomModel(keras.Model):
    def __init__(self, model_set):
        super(RandomModel, self).__init__()
        self.models = model_set


    def call(self, inputs):
        """Calls one of its models at random"""
        return random.sample(self.models, 1)[0](inputs)


def new_model():
    return keras.Sequential([
        keras.layers.Dense(10, activation='softmax')
    ])

model = RandomModel({new_model(), new_model()})
model.build(input_shape=(32, 784))
model.summary()

虽然此代码 runs ,它似乎不允许梯度反向传播。这是它的输出:

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
Total params: 0
Trainable params: 0
Non-trainable params: 0
_________________________________________________________________

最佳答案

我找到了a way去做这个。但是,由于嵌套的 tf.cond 操作,执行速度很慢:

def random_network_applied_to_inputs(inputs, networks):
    """
    Returns a tf.cond tree that does binary search
    before applying a network to the inputs.
    """
    length = len(networks)

    index = tf.random.uniform(
        shape=[],
        minval=0,
        maxval=length,
        dtype=tf.dtypes.int32
    )

    def branch(lower_bound, upper_bound):
        if lower_bound + 1 == upper_bound:
            return networks[lower_bound](inputs)
        else:
            center = (lower_bound + upper_bound) // 2
            return tf.cond(
                pred=index < center,
                true_fn=lambda: branch(lower_bound, center),
                false_fn=lambda: branch(center, upper_bound)
            )

    return branch(0, length)

关于python - 在多种选项中,如何在 TensorFlow 中随机执行其中一种?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/55618901/

相关文章:

tensorflow - tf.data.Dataset 是否支持生成字典结构?

python - ArangoDB 读取超时(读取超时=60)

python - 如何使用 SimplyFold 在 vi​​m 中的折叠文本中维护语法突出显示?

python - 如何终止在后端执行长时间运行的 C/C++ 代码的 python 解释器?

node.js - 将base64图像转换为张量

python-3.x - 刚刚切换到 TensorFlow 2.1 并收到一些烦人的警告

python - Django Rest Framework 分页提供小于页面大小

Python:只设置存在性检查?

python - 在 Keras 中制作采样层

python - 在 El Capitan 10.11.6 上安装 Tensorflow 1.10