tensorflow - 我从这个 TensorFlow 的 csv 阅读器中遗漏了什么?

标签 tensorflow

它主要是从网站上的教程复制粘贴。我收到一个错误:

Invalid argument: ConcatOp : Expected concatenating dimensions in the range [0, 0), but got 0 [[Node: concat = Concat[N=4, T=DT_INT32, _device="/job:localhost/replica:0/task:0/cpu:0"](concat/concat_dim, DecodeCSV, DecodeCSV:1, DecodeCSV:2, DecodeCSV:3)]]



我的 csv 文件的内容是:

3,4,1,8,4


 import tensorflow as tf


filename_queue = tf.train.string_input_producer(["test2.csv"])

reader = tf.TextLineReader()
key, value = reader.read(filename_queue)

# Default values, in case of empty columns. Also specifies the type of the
# decoded result.
record_defaults = [[1], [1], [1], [1], [1]]
col1, col2, col3, col4, col5 = tf.decode_csv(
    value, record_defaults=record_defaults)
# print tf.shape(col1)

features = tf.concat(0, [col1, col2, col3, col4])
with tf.Session() as sess:
  # Start populating the filename queue.
  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(coord=coord)

  for i in range(1200):
    # Retrieve a single instance:
    example, label = sess.run([features, col5])

  coord.request_stop()
  coord.join(threads)

最佳答案

问题是由于程序中张量的形状引起的。 TL;博士 而不是 tf.concat()你应该使用 tf.pack() ,这将转换四个标量 col张量转换为长度为 4 的一维张量。

在我们开始之前,请注意您可以使用 get_shape()任何 Tensor 上的方法对象以获取有关该张量的静态形状信息。例如,代码中注释掉的行可能是:

print col1.get_shape()
# ==> 'TensorShape([])' - i.e. `col1` is a scalar.
value reader.read() 返回的张量是一个标量字符串。 tf.decode_csv(value, record_defaults=[...])record_defaults 的每个元素产生,与 value 形状相同的张量,即在这种情况下的标量。标量是具有单个元素的 0 维张量。 tf.concat(i, xs) 未在标量上定义:它将 N 维张量列表 ( xs ) 连接成一个新的 N 维张量,沿维度 i ,其中 0 <= i < N ,并且没有有效的 i如果 N = 0 .

tf.pack(xs) 运算符旨在简单地解决这个问题。它需要一个列表 k N维张量(形状相同)并将它们打包成一个大小为k的N+1维张量在第 0 维。如果更换 tf.concat()tf.pack() ,您的程序将工作:
# features = tf.concat(0, [col1, col2, col3, col4])
features = tf.pack([col1, col2, col3, col4])

with tf.Session() as sess:
  # Start populating the filename queue.
  # ...

关于tensorflow - 我从这个 TensorFlow 的 csv 阅读器中遗漏了什么?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/33686464/

相关文章:

python - 加权分类交叉熵语义分割

python - tf.Estimator.train 抛出 as_list() 未在未知 TensorShape 上定义

python - Windows10 上的 cross_val_score,并行计算错误

machine-learning - 类型错误 : Value passed to parameter 'input' has DataType string not in list of allowed values: int32, int64、complex64、float32、float64、bool、int8

python - 将重新缩放层(或与此相关的任何层)添加到经过训练的 tensorflow keras 模型

python - Tensorflow如何检查张量行是否只有零?

python - 训练模型后评估建议

python - session 和并行在 TF2.0 中如何工作?

python - 使用 Tensorflow 中具有多个 .csv 的大型数据集的时间序列数据的 LSTM 输入管道

tensorflow - CuDNNLSTM : Failed to call ThenRnnForward