My TensorFlow application crashes with the following error:
Inputs to operation linear/linear_model/weighted_sum_no_bias of type AddN must have the same size and shape: Input 0: [3,1] != input 1: [9,1]
I would appreciate it if someone could point out the root cause.
I have a tfrecord file containing records like the following:
features {
  feature {
    key: "_label"
    value {
      float_list {
        value: 1.0
      }
    }
  }
  feature {
    key: "category"
    value {
      bytes_list {
        value: "14"
        value: "25"
        value: "29"
      }
    }
  }
  feature {
    key: "demo"
    value {
      bytes_list {
        value: "gender:male"
        value: "first_name:baerwulf52"
        value: "country:us"
        value: "city:manlius"
        value: "region:us_ny"
        value: "language:en"
        value: "signup_hour_of_day:1"
        value: "signup_day_of_week:3"
        value: "signup_month_of_year:1"
      }
    }
  }
}
My parsing spec is as follows:
spec = {
    'category': tf.VarLenFeature(dtype=tf.string),
    '_label': tf.FixedLenFeature(shape=(1,), dtype=tf.float32, default_value=None),
    'demo': tf.VarLenFeature(dtype=tf.string),
}
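With this spec, tf.parse_single_example turns each VarLenFeature into a rank-1 SparseTensor. Its dense_shape is just the number of values in that one record, and there is no batch dimension. The plain-Python sketch below models that layout for the record above. It does not call TensorFlow, and SparseLike and parse_single_varlen are hypothetical names used only for illustration.

```python
# Pure-Python stand-in (NOT TensorFlow) for a VarLenFeature parsed from ONE
# serialized example: a rank-1 sparse layout whose dense_shape is
# [number_of_values]. SparseLike / parse_single_varlen are hypothetical names.
from typing import List, NamedTuple, Tuple


class SparseLike(NamedTuple):
    indices: List[Tuple[int, ...]]
    values: List[str]
    dense_shape: Tuple[int, ...]


def parse_single_varlen(values: List[str]) -> SparseLike:
    """One index per value and no leading batch dimension."""
    return SparseLike(
        indices=[(i,) for i in range(len(values))],
        values=list(values),
        dense_shape=(len(values),),
    )


# The 'category' and 'demo' values from the record in the question.
record = {
    "category": ["14", "25", "29"],
    "demo": [
        "gender:male", "first_name:baerwulf52", "country:us", "city:manlius",
        "region:us_ny", "language:en", "signup_hour_of_day:1",
        "signup_day_of_week:3", "signup_month_of_year:1",
    ],
}

parsed = {key: parse_single_varlen(vals) for key, vals in record.items()}
for key, sp in parsed.items():
    print(key, sp.dense_shape)
# category (3,)
# demo (9,)
```

The two features end up with different leading sizes, 3 and 9. These are the same numbers that appear in the error message.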
And my TensorFlow code:
import tensorflow as tf

category = tf.feature_column.categorical_column_with_vocabulary_list(
    key="category", vocabulary_list=["null", "14", "25", "29"])
demo = tf.feature_column.categorical_column_with_vocabulary_list(
    key="demo",
    vocabulary_list=["gender:male", "first_name:baerwulf52",
                     "country:us", "city:manlius", "region:us_ny",
                     "language:en", "signup_hour_of_day:1",
                     "signup_day_of_week:3",
                     "signup_month_of_year:1"])
feature_columns = [category, demo]


def get_input_fn(dataset):
    def _fn():
        # Each dataset element is a single serialized example (no batching).
        iterator = dataset.make_one_shot_iterator()
        next_elem = iterator.get_next()
        ex = tf.parse_single_example(next_elem, features=spec)
        label = ex.pop('_label')
        return ex, label
    return _fn


model = tf.estimator.LinearClassifier(
    feature_columns=feature_columns,
    model_dir=fp("model")  # fp: the asker's own path helper (not shown)
)
# train_dataset: a dataset of serialized records, defined elsewhere
model.train(input_fn=get_input_fn(train_dataset), steps=100)
Best answer
The problem appears to be that LinearClassifier.train expects batched input, while the code is calling:
ex = tf.parse_single_example(next_elem, features=spec)
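Here is one way to see where [3,1] and [9,1] come from. Parsing a single example gives rank-1 sparse tensors: one holds 3 category values and the other holds 9 demo values. My understanding of TF 1.x feature columns is that a rank-1 input is expanded to [N, 1], so N is read as the batch size. Under that assumption, each column's partial weighted sum has shape [N, 1], and AddN cannot add [3, 1] to [9, 1]. The sketch below models only that shape bookkeeping in plain Python. All function names are hypothetical, and none of it calls TensorFlow.

```python
# Pure-Python model of the shape bookkeeping only (NOT TensorFlow).
# Assumption: a rank-1 feature input of N values is expanded to [N, 1],
# so its first dimension N is treated as the batch size.
from typing import Dict, List, Tuple

Shape = Tuple[int, int]


def expanded_shape(values: List[str]) -> Shape:
    """Rank-1 input of N values -> N rows holding one value each."""
    return (len(values), 1)


def column_weighted_sum_shape(values: List[str], units: int = 1) -> Shape:
    """Each linear column contributes one [batch_size, units] partial sum."""
    batch_size, _ = expanded_shape(values)
    return (batch_size, units)


def add_n_shapes(shapes: List[Shape]) -> Shape:
    """Mimic AddN's requirement that every input has exactly the same shape."""
    first = shapes[0]
    for i, shape in enumerate(shapes[1:], start=1):
        if shape != first:
            raise ValueError(f"Input 0: {list(first)} != input {i}: {list(shape)}")
    return first


features: Dict[str, List[str]] = {
    "category": ["14", "25", "29"],
    "demo": ["gender:male", "first_name:baerwulf52", "country:us",
             "city:manlius", "region:us_ny", "language:en",
             "signup_hour_of_day:1", "signup_day_of_week:3",
             "signup_month_of_year:1"],
}

partial_sums = [column_weighted_sum_shape(v) for v in features.values()]
print(partial_sums)  # [(3, 1), (9, 1)]
try:
    add_n_shapes(partial_sums)
except ValueError as err:
    print(err)  # Input 0: [3, 1] != input 1: [9, 1]
```

In other words, the values of one unbatched record are being treated as separate rows of a batch. Because the two features hold different numbers of values, their "batch sizes" disagree.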
This will work if you change the input_fn to:
iterator = dataset.batch(2).make_one_shot_iterator()
next_batch = iterator.get_next()
ex = tf.parse_example(next_batch, features=spec)
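After batching, tf.parse_example returns each VarLenFeature as a SparseTensor with indices of the form [example, position] and a dense_shape of [batch_size, max_values_per_example]. Every feature therefore shares the same leading dimension. The plain-Python sketch below models this layout under two assumptions. First, each column's values are combined per row, as with the default sum combiner, which leaves one [batch_size, 1] partial sum. Second, the batch contains a second record that is made up purely for illustration. The function names are hypothetical.

```python
# Pure-Python model (NOT TensorFlow) of batched VarLenFeature parsing:
# indices are (example, position) pairs and dense_shape is
# [batch_size, max_values_per_example]. Function names are hypothetical.
from typing import List, Tuple

Shape = Tuple[int, int]


def parse_batch_varlen(batch: List[List[str]]) -> Tuple[List[Tuple[int, int]], List[str], Shape]:
    """Return (indices, values, dense_shape) for one VarLen feature over a batch."""
    indices = [(row, col) for row, vals in enumerate(batch) for col in range(len(vals))]
    values = [v for vals in batch for v in vals]
    max_len = max((len(vals) for vals in batch), default=0)
    return indices, values, (len(batch), max_len)


def column_weighted_sum_shape(dense_shape: Shape, units: int = 1) -> Shape:
    """Assumed: values in each row are combined, leaving [batch_size, units]."""
    return (dense_shape[0], units)


# Row 0 is the record from the question; row 1 is a HYPOTHETICAL extra record.
category_batch = [["14", "25", "29"], ["14"]]
demo_batch = [
    ["gender:male", "first_name:baerwulf52", "country:us", "city:manlius",
     "region:us_ny", "language:en", "signup_hour_of_day:1",
     "signup_day_of_week:3", "signup_month_of_year:1"],
    ["country:us", "language:en"],
]

_, _, category_shape = parse_batch_varlen(category_batch)
_, _, demo_shape = parse_batch_varlen(demo_batch)
print(category_shape, demo_shape)  # (2, 3) (2, 9)

sums = [column_weighted_sum_shape(category_shape), column_weighted_sum_shape(demo_shape)]
print(sums)  # [(2, 1), (2, 1)]
assert sums[0] == sums[1]  # the AddN-style shape check now passes
```

The number of values per record still differs (3 versus 9), but that difference now lives in the second dimension. Only the batch size, 2, reaches the first dimension of each partial sum.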
For more on "python - AddN must have the same size and shape with tf.estimator.LinearClassifier", see this similar question on Stack Overflow: https://stackoverflow.com/questions/46816469/