python - TensorFlow recurrent neural network: how to infer a sequence without duplicates?

Tags: python tensorflow recurrent-neural-network sequence-to-sequence

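I am using a seq2seq RNN to generate an output sequence of labels given a seed label. At inference time I want to generate sequences that contain only unique labels (that is, skip labels that have already been added to the output sequence). To do this, I created a sampler object that tries to remember the labels already added to the output and lowers their logit values to -np.inf.

Here is the sampler code: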

import numpy as np
import tensorflow as tf

class InferenceSampler(object):
    def __init__(self, out_weights, out_biases):
        self._out_weights = tf.transpose(out_weights)
        self._out_biases = out_biases

        self._n_tracks = out_weights.shape[0]
        self.ids_mask = tf.zeros([self._n_tracks], name="playlist_mask")

    def __call__(self, decoder_outputs):
        _logits = tf.matmul(decoder_outputs, self._out_weights)
        _logits = tf.nn.bias_add(_logits, self._out_biases)

        # apply mask
        _logits = _logits + self.ids_mask

        _sample_ids = tf.cast(tf.argmax(_logits, axis=-1), tf.int32)

        # update mask
        step_ids_mask = tf.sparse_to_dense(_sample_ids, [self._n_tracks], -np.inf)
        self.ids_mask = self.ids_mask + step_ids_mask

        return _sample_ids

The code for the inference graph looks like this:

self._max_playlist_len = tf.placeholder(tf.int32, ())
self._start_tokens = tf.placeholder(tf.int32, [None])

sample_fn = InferenceSampler(out_weights, out_biases)
with tf.name_scope("inf_decoder"):
    def _end_fn(sample_ids):
        return tf.equal(sample_ids, PAD_ITEM_ID)

    def _next_inputs_fn(sample_ids):
        return tf.nn.embedding_lookup(
            track_embs,
            sample_ids
        )

    _start_inputs = tf.nn.embedding_lookup(
        track_embs,
        self._start_tokens
    )

    helper = tf.contrib.seq2seq.InferenceHelper(
        sample_fn=sample_fn,
        sample_shape=[],
        sample_dtype=tf.int32,
        start_inputs=_start_inputs,
        end_fn=_end_fn,
        next_inputs_fn=_next_inputs_fn
    )
    decoder = tf.contrib.seq2seq.BasicDecoder(
        rnn_cell,
        helper,
        rnn_cell.zero_state(tf.shape(self._start_tokens)[0], tf.float32),
        output_layer=projection_layer
    )
    outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
        decoder,
        maximum_iterations=self._max_playlist_len
    )

self.playlists = outputs.sample_id

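Unfortunately, the results still contain duplicate labels. In addition, when I try to access sample_fn.ids_mask, I get an error: ValueError: Operation 'inf_decoder/decoder/while/BasicDecoderStep/add_1' has been marked as not fetchable.

What am I doing wrong? Is it legitimate to create a sample_fn like this?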
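
A note on the likely cause, offered as an assumption because the thread does not spell it out: dynamic_decode builds the body of its while loop once, so reassigning self.ids_mask in Python only rebinds the attribute to a tensor created inside the loop. Each step would then still add the initial all-zero mask, and the in-loop tensor cannot be fetched, which is consistent with the error above. A mask like this has to travel as explicit loop state. The sketch below shows that pattern in plain NumPy rather than TensorFlow. The names masked_greedy_step, decode_unique and step_logits_fn are hypothetical, and the mask is kept per batch row instead of being shared across the batch.

```python
import numpy as np


def masked_greedy_step(logits, mask):
    """Pick the best not-yet-used id per row and return the updated mask.

    logits: float array [batch, n_items]
    mask:   float array [batch, n_items], 0.0 for free ids, -inf for used ids
    """
    ids = np.argmax(logits + mask, axis=-1)
    new_mask = mask.copy()
    new_mask[np.arange(len(ids)), ids] = -np.inf
    return ids, new_mask


def decode_unique(step_logits_fn, start_ids, n_steps, n_items):
    """Greedy decoding in which the mask is explicit loop state.

    step_logits_fn is a hypothetical stand-in for one decoder step: it maps
    the previous ids [batch] to logits [batch, n_items].
    """
    ids = np.asarray(start_ids)
    mask = np.zeros((len(ids), n_items))
    steps = []
    for _ in range(n_steps):
        # The mask goes in and comes back out on every iteration.
        ids, mask = masked_greedy_step(step_logits_fn(ids), mask)
        steps.append(ids)
    return np.stack(steps, axis=1)  # [batch, n_steps]


# Toy usage: a random table stands in for the decoder.
table = np.random.default_rng(0).normal(size=(6, 6))
playlists = decode_unique(lambda ids: table[ids], np.array([0, 3]),
                          n_steps=4, n_items=6)
```

Each row of playlists contains no repeated id as long as n_steps does not exceed n_items. In a TensorFlow graph the same idea would mean carrying the mask in the decoder or cell state rather than in a Python attribute.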

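Best answer

To work around the problem, I changed the inference so that each RNN step outputs an embedding vector instead of an item_id. Once inference is finished, I convert the embeddings to item_ids.

First, this solution minimizes the number of operations. Second, because I use LSTM/GRU cells, they minimize the probability of observing two exactly identical outputs at different steps of RNN inference.

The new code looks like this: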

with tf.name_scope("inf_decoder"):
    def _sample_fn(decoder_outputs):
        return decoder_outputs

    def _end_fn(sample_ids):
        # Never signal completion; decoding stops at maximum_iterations.
        return tf.tile([False], [n_seeds])

    _start_inputs = tf.nn.embedding_lookup(
        track_embs,
        self._seed_items
    )

    helper = tf.contrib.seq2seq.InferenceHelper(
        sample_fn=_sample_fn,
        sample_shape=[self.emb_size],
        sample_dtype=tf.float32,
        start_inputs=_start_inputs,
        end_fn=_end_fn,
    )
    decoder = tf.contrib.seq2seq.BasicDecoder(
        rnn_cell,
        helper,
        rnn_cell.zero_state(n_seeds, tf.float32),
        output_layer=projection_layer
    )
    outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
        decoder,
        maximum_iterations=self._max_playlist_len
    )

flat_rnn_output = tf.reshape(outputs.rnn_output, [-1, self.emb_size])
flat_logits = tf.matmul(flat_rnn_output, out_weights, transpose_b=True)
flat_logits = tf.nn.bias_add(flat_logits, out_biases)

item_ids = tf.cast(tf.argmax(flat_logits, axis=-1), tf.int32)
playlists = tf.reshape(item_ids, [n_seeds, -1])

self.playlists = playlists
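
The plain argmax above can still pick the same item at two different steps. As far as the code shows, the decoder feeds its own output vectors back as the next inputs, so the chosen ids never influence later steps. Under that assumption, duplicates can be removed after decoding without changing what the RNN computes. Below is a minimal NumPy sketch of such a post-processing pass, not part of the original answer. The function name unique_argmax is hypothetical, and it expects the logits reshaped to [n_seeds, n_steps, n_items].

```python
import numpy as np


def unique_argmax(logits):
    """Greedy per-step argmax that never repeats an id within a sequence.

    logits: float array [n_seeds, n_steps, n_items]
    Returns an int32 array [n_seeds, n_steps].
    """
    n_seeds, n_steps, n_items = logits.shape
    if n_steps > n_items:
        raise ValueError("cannot pick more unique ids than there are items")
    used = np.zeros((n_seeds, n_items), dtype=bool)
    rows = np.arange(n_seeds)
    out = np.empty((n_seeds, n_steps), dtype=np.int32)
    for t in range(n_steps):
        # Hide ids that this sequence has already emitted.
        step = np.where(used, -np.inf, logits[:, t, :])
        out[:, t] = np.argmax(step, axis=-1)
        used[rows, out[:, t]] = True
    return out


# Toy usage: every step prefers item 0, so a plain argmax would repeat it.
toy_logits = np.tile(np.array([3.0, 2.0, 1.0, 0.0]), (2, 3, 1))
toy_playlists = unique_argmax(toy_logits)
```

The pass runs on the fetched logits outside the graph, so it needs no loop state inside dynamic_decode. The trade-off is that the full [n_seeds, n_steps, n_items] logits array has to be materialized.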

Regarding "python - TensorFlow recurrent neural network: how to infer a sequence without duplicates?", a similar question can be found on Stack Overflow: https://stackoverflow.com/questions/47795391/
