python - Implementing a CRF loss in TensorFlow

Tags: python tensorflow crf

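I am trying to use a conditional random field (CRF) loss in a TensorFlow graph.

I am working on a sequence labeling task:

My input is a sequence of elements [A, B, C, D]. Each element belongs to one of 3 classes. The classes are one-hot encoded: an element of class 0 is represented by the vector [1, 0, 0].

My input labels (y) have shape (batch_size x sequence_length x num_classes).

My network produces logits with the same shape.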
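A one-hot label tensor of this shape can be turned into class indices by taking the position of the largest entry in each vector. The following is a minimal plain-Python sketch of that conversion (not TensorFlow code); the helper name `one_hot_to_dense` is hypothetical and only used for illustration.

```python
def one_hot_to_dense(batch):
    """Convert [batch][time][num_classes] one-hot rows to [batch][time] class indices."""
    # row.index(max(row)) returns the first position of the largest entry,
    # which is the class index for a one-hot row.
    return [[row.index(max(row)) for row in sequence] for sequence in batch]


# One sequence [A, B, C, D] whose elements belong to classes 0, 1, 2, 0.
labels = [[[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0]]]
dense = one_hot_to_dense(labels)
print(dense)  # [[0, 1, 2, 0]]
```

The result has shape (batch_size x sequence_length), which is the layout of tag indices rather than one-hot vectors.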

Assume all of my sequences have length 4.

Here is my code:

import tensorflow as tf

sequence_length = 4
num_classes = 3
input_y = tf.placeholder(tf.int32, shape=[None, sequence_length, num_classes])
logits = tf.placeholder(tf.float32, shape=[None, None, num_classes])
dense_y = tf.argmax(input_y, -1, output_type=tf.int32)

log_likelihood, _ = tf.contrib.crf.crf_log_likelihood(logits, dense_y, sequence_length)

I get the following error:

  File "", line 1, in
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/crf/python/ops/crf.py", line 182, in crf_log_likelihood
    transition_params)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/crf/python/ops/crf.py", line 109, in crf_sequence_score
    false_fn=_multi_seq_fn)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/layers/utils.py", line 206, in smart_cond
    pred, true_fn=true_fn, false_fn=false_fn, name=name)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/smart_cond.py", line 59, in smart_cond
    name=name)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/util/deprecation.py", line 432, in new_func
    return func(*args, **kwargs)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/control_flow_ops.py", line 2063, in cond
    orig_res_t, res_t = context_t.BuildCondBranch(true_fn)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/control_flow_ops.py", line 1913, in BuildCondBranch
    original_result = fn()
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/crf/python/ops/crf.py", line 95, in _single_seq_fn
    array_ops.concat([example_inds, tag_indices], axis=1))
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/gen_array_ops.py", line 2975, in gather_nd
    "GatherNd", params=params, indices=indices, name=name)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 3392, in create_op
    op_def=op_def)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 1734, in __init__
    control_input_ops)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 1570, in _create_c_op
    raise ValueError(str(e))
ValueError: indices.shape[-1] must be <= params.rank, but saw indices shape: [?,5] and params shape: [?,3] for 'cond/GatherNd' (op: 'GatherNd') with input shapes: [?,3], [?,5]

Best answer

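The error is caused by the sequence-length argument having the wrong rank. It must be a vector (one length per example in the batch), not a scalar.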
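To illustrate what such a vector contains, here is a plain-Python sketch (not TensorFlow code) that counts the non-padding token IDs in each row of a padded batch. The helper name `sequence_lengths`, the `pad_id` parameter, and the convention that ID 0 marks padding are assumptions made for illustration.

```python
def sequence_lengths(padded_ids, pad_id=0):
    """Return one length per example: the count of non-padding IDs in each row."""
    return [sum(1 for token in row if token != pad_id) for row in padded_ids]


batch = [
    [7, 3, 9, 2],  # full-length sequence
    [5, 8, 0, 0],  # two real tokens followed by two padding positions
]
lengths = sequence_lengths(batch)
print(lengths)  # [4, 2]
```

The result has one entry per example, so its shape is [batch_size] rather than a single number.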

import tensorflow as tf

num_classes = 3
# Token IDs, shape [batch_size, max_time]; ID 0 is assumed to be padding.
input_x = tf.placeholder(tf.int32, shape=[None, None], name="input_x")
# One-hot labels; the time dimension is left dynamic to match the logits.
input_y = tf.placeholder(tf.int32, shape=[None, None, num_classes])
# Vector of shape [batch_size]: the number of non-padding tokens per example.
sequence_lengths = tf.reduce_sum(tf.sign(input_x), 1)

# After some network operations you will end up with logits.

logits = tf.placeholder(tf.float32, shape=[None, None, num_classes])
dense_y = tf.argmax(input_y, -1, output_type=tf.int32)
log_likelihood, _ = tf.contrib.crf.crf_log_likelihood(logits, dense_y, sequence_lengths)
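For intuition about what a linear-chain CRF log-likelihood computes, and why it needs a length for each example, here is a plain-Python sketch for a single sequence. It is not TensorFlow's implementation: the function names are hypothetical, `transitions[i][j]` is assumed to score moving from tag i to tag j, and `length` is assumed to be at least 1.

```python
import math


def log_sum_exp(values):
    """Numerically stable log(sum(exp(v))) over a list of floats."""
    peak = max(values)
    return peak + math.log(sum(math.exp(v - peak) for v in values))


def sequence_score(unary, tags, transitions, length):
    """Unnormalized score of one tag path, using only the first `length` steps."""
    score = sum(unary[t][tags[t]] for t in range(length))
    score += sum(transitions[tags[t - 1]][tags[t]] for t in range(1, length))
    return score


def log_partition(unary, transitions, length):
    """Log of the summed exp-scores of all tag paths (forward algorithm)."""
    num_tags = len(transitions)
    alphas = list(unary[0])
    for t in range(1, length):
        alphas = [
            unary[t][j]
            + log_sum_exp([alphas[i] + transitions[i][j] for i in range(num_tags)])
            for j in range(num_tags)
        ]
    return log_sum_exp(alphas)


def log_likelihood_single(unary, tags, transitions, length):
    """Log-probability of `tags` under the CRF, ignoring steps beyond `length`."""
    return sequence_score(unary, tags, transitions, length) - log_partition(
        unary, transitions, length
    )


# Four time steps, three tags; the last step is padding and must not count.
unary = [[1.0, 0.2, -0.5], [0.3, 1.5, 0.0], [0.1, 0.4, 2.0], [9.0, 9.0, 9.0]]
transitions = [[0.5, -0.1, 0.0], [0.0, 0.2, 0.3], [-0.4, 0.1, 0.6]]
tags = [0, 1, 2, 0]
ll = log_likelihood_single(unary, tags, transitions, length=3)
print(ll)
```

Because only the first `length` steps enter the score and the normalizer, the values in the padded fourth step have no effect on the result. In a batch, each example can have a different number of real steps, which is why the lengths are passed as a vector.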

Source: a similar question on Stack Overflow, https://stackoverflow.com/questions/51105062/
