我在 Tensorflow 1.10 上实现了多数投票算法(计算不同分类器的预测),对于预测大小为 1000 的数据集(MNIST)非常慢(10 个分类器需要 3 个小时以上)。根据我的猜测,这是因为在我的代码上多次调用 session.run() ,但我如何优化它?
def majority_voting(session, x, y):
votes = []
for i in range(number_of_ensemble_modules):
# run the training
feature_extractor = iterators[i][3]
input, label = feature_extractor(x, y)
transformed_x = session.run(input)
ensemble_prediction = nn_models[0][i][0][3]
prediction = session.run(ensemble_prediction, feed_dict={X: transformed_x, Y: y})
votes.append(prediction[0])
nearest_k_y, idx, vote = tf.unique_with_counts(tf.convert_to_tensor(votes, tf.int64))
majority = tf.argmax(vote)
predict_res = tf.gather(nearest_k_y, majority)
return predict_res
def calculate_ensemble_accuracy():
accuracy = 0
for j in range(voting_iterations):
accuracy += 0
(features, labels) = session.run(next_element)
vote = majority_voting(session, features, labels)
correct_label = session.run(tf.argmax(labels, axis=1))
if vote == correct_label[0]:
accuracy += 1
return accuracy
最佳答案
一些可能解决您的问题的提示。
1-在创建 tensorflow 图
之前进行特征提取。例如,如果您创建 TfIDF
特征向量,您可以在预处理步骤中执行此操作,并保存 numpy 作为图形的输入。
input, label = feature_extractor(x, y)
2-删除不必要的session.run()
。例如,当您调用 Optimizer 时,它会自动调用 x_transformed。
transformed_x = session.run(input)
3-以更好的方式使用tf.data
(Dataset API)。无需调用sess.run(next_element)
。因为 next_element
是图表的一部分。
关于python - Tensorflow 在多数投票上表现缓慢,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/51992098/