c++ - 避免计算图中的虚函数调用

标签 c++ performance virtual-functions numerical-computing

我正在使用 DAG(有向无环图)来表示和评估表达式;每个节点代表一个操作(+-/*、累积等),整个表达式的评估是通过按拓扑排序顺序依次评估每个节点来实现的。每个节点继承一个基类RefNode并根据它代表的运算符实现一个虚函数evaluate。 Node 类在代表运算符的仿函数上模板化。节点评估顺序在 vector<RefNode*> 中维护。与 ->evaluate()对每个元素的调用。

一些快速分析显示虚拟 evaluate将加法节点减慢 2 倍 [1],无论是开销还是破坏分支预测。

第一步将类型信息编码为整数,并使用 static_cast因此。这确实有帮助,但它很笨重,我不想在代码的热门部分跳来跳去。

struct RefNode {
    double output;
    inline virtual void evaluate(){}
};

template<class T>
struct Node : RefNode {
    double* inputs[NODE_INPUT_BUFFER_LENGTH];
    T evaluator;
    inline void evaluate(){ evaluator(inputs, output); }
};

struct Add {
    inline void operator()(double** inputs, double &output)
    {
        output=*inputs[0]+*inputs[1];
    }
};

评估可能如下所示:

Node<Add>* node_1 = ...
Node<Add>* node_2 = ...
std::vector<RefNode*> eval_vector;

eval_vector.push_back(node_1);
eval_vector.push_back(node_2);

for (auto&& n : eval_vector) {
    n->evaluate();
}

我有以下问题,请记住性能至关重要:

  1. 在这种情况下如何避免虚函数?
  2. 如果不是,我如何改变表示表达式图的方式以支持多个操作,其中一些操作必须保持状态,并避免虚函数调用。
  3. Tensorflow/Theano 等其他框架如何表示计算图?

[1] 我的系统上的单个加法操作使用虚函数需要大约 2.3ns,没有虚函数需要 1.1ns。虽然这很小,但整个计算图主要是加法节点,因此可以节省很大一部分时间。

最佳答案

如评论中所述,您需要在编译时了解图表才能删除虚拟分派(dispatch)。为此,您只需要使用 std::tuple:

auto eval_vector = std::make_tuple(
    Node<Add>{ ... },
    Node<Add>{ ... },
    ...
);

那么,只需要去掉virtualoverride关键字,去掉基类中的空函数即可。

你会发现基于范围的for循环还不支持元组。要对其进行迭代,您将需要该函数:

template<typename T, typename F, std::size_t... S>
void for_tuple(std::index_sequence<S...>, T&& tuple, F&& function) {
    int unpack[] = {(static_cast<void>(
        function(std::get<S>(std::forward<T>(tuple))
    ), 0)..., 0};
    static_cast<void>(unpack);
}

template<typename T, typename F>
void for_tuple(T&& tuple, F&& function) {
    constexpr std::size_t N = std::tuple_size<std::remove_reference_t<T>>::value;
    for_tuple(std::make_index_sequence<N>{}, std::forward<T>(tuple), std::forward<F>(function));
}

然后您可以像这样迭代您的元组:

for_tuple(eval_vector, [](auto&& node){
    node.evaluate();
});

关于c++ - 避免计算图中的虚函数调用,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/42487603/

相关文章:

c++ - 转换为未知类型

c++ - 为什么我们需要 C++ 中的虚函数?

C++ 使用虚函数

c++ - 在可简单复制的结构之间使用类型双关语会有多邪恶?

c++ - 在 Ubuntu 上使用 GLFW 设置 OpenGL NetBeans 项目

performance - grails中的 transient 属性会影响应用程序的内存使用吗?

c++ - 知道为什么 QHash 和 QMap 返回 const T 而不是 const T& 吗?

c - 变异数组中最小值及其偏移量的数据结构

c++ - vptr 和 vtable 如何在下面的虚拟相关代码中工作?

c++ - 代码运行后立即崩溃