Java伪尾调用递归产生更好的性能

标签 java performance data-structures jvm

我只是编写一些简单的实用程序来计算 linkedList 的长度,这样 linkedList 就不会托管其大小/长度的“内部”计数器。考虑到这一点,我有 3 个简单的方法:

  1. 迭代linkedList,直到到达“结束”
  2. 递归计算长度
  3. 递归计算长度,无需将控制权返回给调用函数(使用某种尾部调用递归)

下面是一些捕获这 3 种情况的代码:

// 1. iterative approach
public static <T> int getLengthIteratively(LinkedList<T> ll) {

    int length = 0;
    for (Node<T> ptr = ll.getHead(); ptr != null; ptr = ptr.getNext()) {
        length++;
    }

    return length;
}

// 2. recursive approach
public static <T> int getLengthRecursively(LinkedList<T> ll) {
    return getLengthRecursively(ll.getHead());
}

private static <T> int getLengthRecursively(Node<T> ptr) {

    if (ptr == null) {
        return 0;
    } else {
        return 1 + getLengthRecursively(ptr.getNext());
    }
}

// 3. Pseudo tail-recursive approach
public static <T> int getLengthWithFakeTailRecursion(LinkedList<T> ll) {
    return getLengthWithFakeTailRecursion(ll.getHead());
}

private static <T> int getLengthWithFakeTailRecursion(Node<T> ptr) {
    return getLengthWithFakeTailRecursion(ptr, 0);
}

private static <T> int getLengthWithFakeTailRecursion(Node<T> ptr, int result) {
    if (ptr == null) {
        return result;
    } else {
        return getLengthWithFakeTailRecursion(ptr.getNext(), result + 1);
    }
}

现在我知道 JVM 不支持开箱即用的尾递归,但是当我运行一些具有约 10k 节点的字符串 linkedList 的简单测试时,我注意到 getLengthWithFakeTailRecursion 始终优于 getLengthRecursively 方法(约 40%)。增量是否仅归因于案例#2 的控制权被传回每个节点并且我们被迫遍历所有堆栈帧这一事实?

编辑:这是我用来检查性能数字的简单测试:

public class LengthCheckerTest {

@Test
public void testLengthChecking() {

    LinkedList<String> ll = new LinkedList<String>();
    int sizeOfList = 12000;
    // int sizeOfList = 100000; // Danger: This causes a stackOverflow in recursive methods!
    for (int i = 1; i <= sizeOfList; i++) {
        ll.addNode(String.valueOf(i));
    }

    long currTime = System.nanoTime();
    Assert.assertEquals(sizeOfList, LengthChecker.getLengthIteratively(ll));
    long totalTime = System.nanoTime() - currTime;
    System.out.println("totalTime taken with iterative approach: " + (totalTime / 1000) + "ms");

    currTime = System.nanoTime();
    Assert.assertEquals(sizeOfList, LengthChecker.getLengthRecursively(ll));
    totalTime = System.nanoTime() - currTime;
    System.out.println("totalTime taken with recursive approach: " + (totalTime / 1000) + "ms");

    // Interestingly, the fakeTailRecursion always runs faster than the vanillaRecursion
    // TODO: Look into whether stack-frame collapsing has anything to do with this
    currTime = System.nanoTime();
    Assert.assertEquals(sizeOfList, LengthChecker.getLengthWithFakeTailRecursion(ll));
    totalTime = System.nanoTime() - currTime;
    System.out.println("totalTime taken with fake TCR approach: " + (totalTime / 1000) + "ms");
}
}

最佳答案

您的基准测试方法存在缺陷。您在同一个 JVM 中执行所有三个测试,因此它们的位置并不相同。当执行假尾测试时,LinkedListNode 类已经经过 JIT 编译,因此运行速度更快。您可以更改测试顺序,您会看到不同的数字。每个测试都应该在单独的 JVM 中执行。

我们写简单的JMH microbenchmark对于您的情况:

import java.util.concurrent.TimeUnit;

import org.openjdk.jmh.infra.Blackhole;
import org.openjdk.jmh.annotations.*;

// 5 warm-up iterations, 500 ms each, then 10 measurement iterations 500 ms each
// repeat everything three times (with JVM restart)
// output average time in microseconds
@Warmup(iterations = 5, time = 500, timeUnit = TimeUnit.MILLISECONDS)
@Measurement(iterations = 10, time = 500, timeUnit = TimeUnit.MILLISECONDS)
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.MICROSECONDS)
@Fork(3)
@State(Scope.Benchmark)
public class ListTest {
    // You did not supply Node and LinkedList implementation
    // but I assume they look like this
    static class Node<T> {
        final T value;
        Node<T> next;

        public Node(T val) {value = val;}
        public void add(Node<T> n) {next = n;}

        public Node<T> getNext() {return next;}
    }

    static class LinkedList<T> {
        Node<T> head;

        public void setHead(Node<T> h) {head = h;}
        public Node<T> getHead() {return head;}
    }

    // Code from your question follows

    // 1. iterative approach
    public static <T> int getLengthIteratively(LinkedList<T> ll) {

        int length = 0;
        for (Node<T> ptr = ll.getHead(); ptr != null; ptr = ptr.getNext()) {
            length++;
        }

        return length;
    }

    // 2. recursive approach
    public static <T> int getLengthRecursively(LinkedList<T> ll) {
        return getLengthRecursively(ll.getHead());
    }

    private static <T> int getLengthRecursively(Node<T> ptr) {

        if (ptr == null) {
            return 0;
        } else {
            return 1 + getLengthRecursively(ptr.getNext());
        }
    }

    // 3. Pseudo tail-recursive approach
    public static <T> int getLengthWithFakeTailRecursion(LinkedList<T> ll) {
        return getLengthWithFakeTailRecursion(ll.getHead());
    }

    private static <T> int getLengthWithFakeTailRecursion(Node<T> ptr) {
        return getLengthWithFakeTailRecursion(ptr, 0);
    }

    private static <T> int getLengthWithFakeTailRecursion(Node<T> ptr, int result) {
        if (ptr == null) {
            return result;
        } else {
            return getLengthWithFakeTailRecursion(ptr.getNext(), result + 1);
        }
    }

    // Benchmarking code

    // Measure for different list length        
    @Param({"10", "100", "1000", "10000"})
    int n;
    LinkedList<Integer> list;

    @Setup    
    public void setup() {
        list = new LinkedList<>();
        Node<Integer> cur = new Node<>(0);
        list.setHead(cur);
        for(int i=1; i<n; i++) {
            Node<Integer> next = new Node<>(i);
            cur.add(next);
            cur = next;
        }
    }

    // Do not forget to return result to the caller, so it's not optimized out
    @Benchmark    
    public int testIteratively() {
        return getLengthIteratively(list);
    }

    @Benchmark    
    public int testRecursively() {
        return getLengthRecursively(list);
    }

    @Benchmark    
    public int testRecursivelyFakeTail() {
        return getLengthWithFakeTailRecursion(list);
    }
}

这是我的机器上的结果(x64 Win7,Java 8u71)

Benchmark                           (n)  Mode  Cnt   Score    Error  Units
ListTest.testIteratively             10  avgt   30   0,009 ±  0,001  us/op
ListTest.testIteratively            100  avgt   30   0,156 ±  0,001  us/op
ListTest.testIteratively           1000  avgt   30   2,248 ±  0,036  us/op
ListTest.testIteratively          10000  avgt   30  26,416 ±  0,590  us/op
ListTest.testRecursively             10  avgt   30   0,014 ±  0,001  us/op
ListTest.testRecursively            100  avgt   30   0,191 ±  0,003  us/op
ListTest.testRecursively           1000  avgt   30   3,599 ±  0,031  us/op
ListTest.testRecursively          10000  avgt   30  40,071 ±  0,328  us/op
ListTest.testRecursivelyFakeTail     10  avgt   30   0,015 ±  0,001  us/op
ListTest.testRecursivelyFakeTail    100  avgt   30   0,190 ±  0,002  us/op
ListTest.testRecursivelyFakeTail   1000  avgt   30   3,609 ±  0,044  us/op
ListTest.testRecursivelyFakeTail  10000  avgt   30  41,534 ±  1,186  us/op

如您所见,假尾速度与简单递归速度相同(在误差范围内),但比迭代方法慢 20-60%。所以你的结果没有被重现。

如果您实际上想要的不是稳态测量结果,而是单次测量结果(无需预热),您可以使用以下选项启动相同的基准测试:-ss -wi 0 -i 1 -f 10 。结果如下:

Benchmark                           (n)  Mode  Cnt    Score     Error  Units
ListTest.testIteratively             10    ss   10   16,095 ±   0,831  us/op
ListTest.testIteratively            100    ss   10   19,780 ±   6,440  us/op
ListTest.testIteratively           1000    ss   10   74,316 ±  26,434  us/op
ListTest.testIteratively          10000    ss   10  366,496 ±  42,299  us/op
ListTest.testRecursively             10    ss   10   19,594 ±   7,084  us/op
ListTest.testRecursively            100    ss   10   21,973 ±   0,701  us/op
ListTest.testRecursively           1000    ss   10  165,007 ±  54,915  us/op
ListTest.testRecursively          10000    ss   10  563,739 ±  74,908  us/op
ListTest.testRecursivelyFakeTail     10    ss   10   19,454 ±   4,523  us/op
ListTest.testRecursivelyFakeTail    100    ss   10   25,518 ±  11,802  us/op
ListTest.testRecursivelyFakeTail   1000    ss   10  158,336 ±  43,646  us/op
ListTest.testRecursivelyFakeTail  10000    ss   10  755,384 ± 232,940  us/op

正如您所看到的,第一次启动比后续启动慢很多倍。而且你的结果仍然没有重现。我观察到,对于 n = 10000testRecursivelyFakeTail 实际上较慢(但在预热后,它达到了与 testRecursively 相同的峰值速度。

关于Java伪尾调用递归产生更好的性能,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/35423375/

相关文章:

用于检查已签名 JAR 的 Java 实用程序

java - 如何在现有类型中声明或定义通用 xsd 类型。

带有ant的java版本控制系统

python - 纯python的数据类型转换比numpy快

language-agnostic - 有哪些鲜为人知但有用的数据结构?

java - 有某种方法可以制作子菜单项吗?

java - JVM 垃圾收集 - 跟踪年轻一代中的 Activity 对象

algorithm - 用于设计具有高效插入、删除和最高值检索的缓存的数据结构

python - 寻找到达数组末尾的最小成本

Scala - 尝试变换、匹配和映射方法