我正在尝试遵循 structured data models 的 Tensorflow 教程(我是初学者)在此过程中发生了一些变化。
我的目的是创建一个模型,向其提供数据(csv 格式),看起来像这样(该示例只有 2 个功能,但我想在弄清楚后对其进行扩展):
power_0,power_1,result
0.2,0.3,draw
0.8,0.1,win
0.3,0.1,draw
0.7,0.2,win
0.0,0.4,lose
我使用以下代码创建了模型:
def get_labels(df, label, mapping):
raw_y_true = df.pop(label)
y_true = np.zeros((len(raw_y_true)))
for i, raw_label in enumerate(raw_y_true):
y_true[i] = mapping[raw_label]
return y_true
tf.compat.v1.enable_eager_execution()
mapping_to_numbers = {'win': 0, 'draw': 1, 'lose': 2}
data_frame = pd.read_csv('data.csv')
data_frame.head()
train, test = train_test_split(data_frame, test_size=0.2)
train, val = train_test_split(train, test_size=0.2)
train_labels = np.array(get_labels(train, label='result', mapping=mapping_to_numbers))
val_labels = np.array(get_labels(val, label='result', mapping=mapping_to_numbers))
test_labels = np.array(get_labels(test, label='result', mapping=mapping_to_numbers))
train_features = np.array(train)
val_features = np.array(val)
test_features = np.array(test)
model = tf.keras.Sequential([
tf.keras.layers.Dense(128, activation='relu', input_shape=(train_features.shape[-1],)),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dropout(0.3),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dropout(0.3),
tf.keras.layers.Dense(3, activation='sigmoid'),
])
model.compile(
optimizer='adam',
loss='binary_crossentropy',
metrics=['accuracy'],
run_eagerly=True)
epochs = 10
batch_size = 100
history = model.fit(
train_features,
train_labels,
epochs=epochs,
validation_data=(val_features, val_labels))
input_data_frame = pd.read_csv('input.csv')
input_data_frame.head()
input_data = np.array(input_data_frame)
print(model.predict(input_data))
input.csv 如下所示:
power_0,power_1
0.8,0.1
0.7,0.2
实际结果是:
[[0.00604381 0.00242573 0.00440606]
[0.01321151 0.00634229 0.01041476]]
我希望获得每个标签的概率(“赢”、“平局”和“输”),有人可以帮我吗?
提前致谢
最佳答案
在此行tf.keras.layers.Dense(3,activation='sigmoid')
中使用softmax激活。
关于python-3.x - Tensorflow 结构化数据 model.predict() 返回错误的概率,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/58227659/