python - Tensorflow 在多数投票上表现缓慢

标签 python tensorflow machine-learning

我在 Tensorflow 1.10 上实现了多数投票算法(计算不同分类器的预测),对于预测大小为 1000 的数据集(MNIST)非常慢(10 个分类器需要 3 个小时以上)。根据我的猜测,这是因为在我的代码上多次调用 session.run() ,但我如何优化它?

def majority_voting(session, x, y):
    votes = []
    for i in range(number_of_ensemble_modules):
        # run the training
        feature_extractor = iterators[i][3]
        input, label = feature_extractor(x, y)
        transformed_x = session.run(input)
        ensemble_prediction = nn_models[0][i][0][3]
        prediction = session.run(ensemble_prediction, feed_dict={X: transformed_x, Y: y})
        votes.append(prediction[0])
    nearest_k_y, idx, vote = tf.unique_with_counts(tf.convert_to_tensor(votes, tf.int64))
    majority = tf.argmax(vote)
    predict_res = tf.gather(nearest_k_y, majority)
    return predict_res


def calculate_ensemble_accuracy():
    accuracy = 0
    for j in range(voting_iterations):
        accuracy += 0
        (features, labels) = session.run(next_element)
        vote = majority_voting(session, features, labels)
        correct_label = session.run(tf.argmax(labels, axis=1))
        if vote == correct_label[0]:
            accuracy += 1
    return accuracy

最佳答案

一些可能解决您的问题的提示。

1-在创建 tensorflow 图之前进行特征提取。例如,如果您创建 TfIDF 特征向量,您可以在预处理步骤中执行此操作,并保存 numpy 作为图形的输入。

input, label = feature_extractor(x, y)

2-删除不必要的session.run()。例如,当您调用 Optimizer 时,它会自动调用 x_transformed。

transformed_x = session.run(input)

3-以更好的方式使用tf.data (Dataset API)。无需调用sess.run(next_element)。因为 next_element 是图表的一部分。

关于python - Tensorflow 在多数投票上表现缓慢,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/51992098/

相关文章:

python - 摆脱python matplotlib条形图中的灰色背景

python - 如何使用 hmmlearn 在 Python 中运行隐马尔可夫模型?

tensorflow - 如何正确使用tf-coreml?

c++ - 如何为SVM训练OpenCV3形成数据

machine-learning - 如何更新tensorflow以支持tf.contrib?

machine-learning - 数据分区中的类标签

python - 无法接收消息 - Python 套接字

python - pybind11 从 C++ 修改 numpy 数组

python-3.x - tensorflow 类型错误: run() got multiple values for argument 'feed_dict'

python - 从大量图像(*.jpg)和标签(*.mat)制作 tensorflow 数据集