performance - 在Haskell中执行恒定空间嵌套循环的正确方法是什么?

标签 performance loops haskell

在Haskell中,有两种显而易见的“惯用”方式来执行嵌套循环:使用列表monad或使用forM_代替传统的fors。我设置了一个基准以确定是否将它们编译为紧密循环:

import Control.Monad.Loop
import Control.Monad.Primitive
import Control.Monad
import Control.Monad.IO.Class
import qualified Data.Vector.Unboxed.Mutable as MV
import qualified Data.Vector.Unboxed as V

times = 100000
side  = 100

-- Using `forM_` to replace traditional fors
test_a mvec = 
    forM_ [0..times-1] $ \ n -> do
        forM_ [0..side-1] $ \ y -> do
            forM_ [0..side-1] $ \ x -> do
                MV.write mvec (y*side+x) 1

-- Using the list monad to replace traditional forms
test_b mvec = sequence_ $ do
    n <- [0..times-1]
    y <- [0..side-1]
    x <- [0..side-1]
    return $ MV.write mvec (y*side+x) 1

main = do
    let vec = V.generate (side*side) (const 0)
    mvec <- V.unsafeThaw vec :: IO (MV.MVector (PrimState IO) Int)
    -- test_a mvec
    -- test_b mvec
    vec' <- V.unsafeFreeze mvec :: IO (V.Vector Int)
    print $ V.sum vec'

该测试创建一个100x100的 vector ,使用嵌套循环将1写入每个索引,并重复100k次。仅使用ghc -O2 test.hs -o test(ghc版本7.8.4)进行编译,结果为:3.853s版本为 forM_ 10.460s list monad 。为了提供引用,我还使用JavaScript编写了该测试程序:
var side  = 100;
var times = 100000;
var vec   = [];

for (var i=0; i<side*side; ++i)
    vec.push(0);

for (var n=0; n<times; ++n)
    for (var y=0; y<side; ++y)
        for (var x=0; x<side; ++x)
            vec[x+y*side] = 1;

var s = 0;
for (var i=0; i<side*side; ++i)
    s += vec[i];

console.log(s);

这个等效的JavaScript程序需要 1s 来完成,击败了Haskell的未装箱 vector ,这是不寻常的,这表明Haskell不在恒定空间中运行循环,而是进行分配。然后,我找到了一个声称提供类型保证的紧密循环Control.Monad.Loop的库:
-- Using `for` from Control.Monad.Loop
test_c mvec = exec_ $ do
    n <- for 0 (< times) (+ 1)
    x <- for 0 (< side) (+ 1)
    y <- for 0 (< side) (+ 1)
    liftIO (MV.write mvec (y*side+x) 1)

可以在 1s 中运行。该库不是很常用,而且离惯用语言还很远,所以是什么来快速进行恒定空间二维计算的惯用方法? (请注意,这不是REPA的情况,因为我想在网格上执行任意IO操作。)

最佳答案

用GHC编写紧密的变异代码有时会很棘手。我将写一些不同的东西,可能以比我更喜欢的方式杂乱无章。

对于初学者,无论如何我们都应该使用GHC 7.10,因为otherwiseforM_和list monad解决方案永远不会融合。

另外,我将MV.write替换为MV.unsafeWrite,部分是因为它更快,但更重要的是,它减少了最终Core中的困惑情况。从现在开始,运行时统计信息将使用unsafeWrite引用代码。

令人恐惧的漂浮

即使在GHC 7.10中,我们也应该首先注意到所有这些[0..times-1][0..side-1]表达式,因为如果不采取必要步骤,它们每次都会破坏性能。问题在于它们是恒定范围,并且-ffull-laziness(默认在-O上启用)将它们 float 到顶层。这样可以防止列表融合,并且在Int#范围内进行迭代比在盒装的Int -s列表中进行迭代便宜,因此这是一个非常糟糕的优化。

让我们以秒为单位查看一些运行时,以获取未更改的代码(除了使用unsafeWrite之外)。使用ghc -O2 -fllvm,我将+RTS -s用于计时。

test_a: 1.6
test_b: 6.2
test_c: 0.6

为了查看GHC Core,我使用了ghc -O2 -ddump-simpl -dsuppress-all -dno-suppress-type-signatures

test_a的情况下,[0..99]范围被取消:
main4 :: [Int]
main4 = eftInt 0 99 -- means "enumFromTo" for Int.

尽管最外面的[0..9999]循环融合到了尾递归帮助器中:
letrec {
          a3_s7xL :: Int# -> State# RealWorld -> (# State# RealWorld, () #)
          a3_s7xL =
            \ (x_X5zl :: Int#) (s1_X4QY :: State# RealWorld) ->
              case a2_s7xF 0 s1_X4QY of _ { (# ipv2_a4NA, ipv3_a4NB #) ->
              case x_X5zl of wild_X1S {
                __DEFAULT -> a3_s7xL (+# wild_X1S 1) ipv2_a4NA;
                99999 -> (# ipv2_a4NA, () #)
              }
              }; }

test_b的情况下,再次仅提升[0..99]。但是,test_b慢得多,因为它必须构建和排序实际的[IO ()]列表。至少GHC足够明智,只为两个内部循环构建一个[IO ()],然后对10000进行排序。
 let {
          lvl7_s4M5 :: [IO ()]
          lvl7_s4M5 = -- omitted
        letrec {
          a2_s7Av :: Int# -> State# RealWorld -> (# State# RealWorld, () #)
          a2_s7Av =
            \ (x_a5xi :: Int#) (eta_B1 :: State# RealWorld) ->
              letrec {
                a3_s7Au
                  :: [IO ()] -> State# RealWorld -> (# State# RealWorld, () #)
                a3_s7Au =
                  \ (ds_a4Nu :: [IO ()]) (eta1_X1c :: State# RealWorld) ->
                    case ds_a4Nu of _ {
                      [] ->
                        case x_a5xi of wild1_X1y {
                          __DEFAULT -> a2_s7Av (+# wild1_X1y 1) eta1_X1c;
                          99999 -> (# eta1_X1c, () #)
                        };
                      : y_a4Nz ys_a4NA ->
                        case (y_a4Nz `cast` ...) eta1_X1c
                        of _ { (# ipv2_a4Nf, ipv3_a4Ng #) ->
                        a3_s7Au ys_a4NA ipv2_a4Nf
                        }
                    }; } in
              a3_s7Au lvl7_s4M5 eta_B1; } in
-- omitted

我们该如何补救?我们可以用{-# OPTIONS_GHC -fno-full-laziness #-}来解决问题。在我们的情况下,这确实有很大帮助:
test_a: 0.5
test_b: 0.48
test_c: 0.5

另外,我们可以摆弄INLINE编译指示。显然,在让 float 完成之后内联函数可以保持良好的性能。我发现,即使没有编译指示,GHC也会内联我们的测试功能,但是显式编译指示只会使其在 float 后才内联。例如,在没有-fno-full-laziness的情况下,这将导致良好的性能:
test_a mvec = 
    forM_ [0..times-1] $ \ n -> 
        forM_ [0..side-1] $ \ y -> 
            forM_ [0..side-1] $ \ x -> 
                MV.unsafeWrite mvec (y*side+x) 1
{-# INLINE test_a #-}

但是过早内联会导致性能下降:
test_a mvec = 
    forM_ [0..times-1] $ \ n -> 
        forM_ [0..side-1] $ \ y -> 
            forM_ [0..side-1] $ \ x -> 
                MV.unsafeWrite mvec (y*side+x) 1
{-# INLINE [~2] test_a #-} -- "inline before the first phase please"

这种INLINE解决方案的问题在于,面对GHC的 float 冲击,它相当脆弱。例如,手动内联不能保留性能。以下代码很慢,因为与INLINE [~2]类似,它为GHC提供了 float 的机会:
main = do
    let vec = V.generate (side*side) (const 0)
    mvec <- V.unsafeThaw vec :: IO (MV.MVector (PrimState IO) Int)
    forM_ [0..times-1] $ \ n -> 
        forM_ [0..side-1] $ \ y -> 
            forM_ [0..side-1] $ \ x -> 
                MV.unsafeWrite mvec (y*side+x) 1    

那我们该怎么办呢?

首先,对于那些想编写高性能代码并知道自己在做什么的人,我认为使用-fno-full-laziness是一个完美可行的方法,甚至是更可取的选择。例如,它在 unordered-containers 中使用。有了它,我们可以更精确地控制共享,并且我们总是可以手动 float 或内联。

对于更常规的代码,我相信使用Control.Monad.Loop或提供该功能的任何其他软件包都没有错。许多Haskell用户并不依赖小型的“附带”库。我们还可以按照所需的一般性重新实现for。例如,以下各项与其他解决方案一样出色:
for :: Monad m => a -> (a -> Bool) -> (a -> a) -> (a -> m ()) -> m ()
for init while step body = go init where
  go !i | while i = body i >> go (step i)
  go i = return ()
{-# INLINE for #-}

在真正恒定的空间中循环

起初我对堆分配中的+RTS -s数据感到非常困惑。 test_a-fno-full-laziness一起分配,而且没有完全延迟的test_c分配,这些分配与times迭代次数成线性比例,但是具有完全延迟的test_b仅分配给 vector :
-- with -fno-full-laziness, no INLINE pragmas
test_a: 242,521,008 bytes
test_b: 121,008 bytes
test_c: 121,008 bytes -- but 240,120,984 with full laziness!

同样,在这种情况下,INLINEtest_c编译指示根本没有帮助。

我花了一些时间尝试在相关程序的Core中找到堆分配的迹象,但没有成功,直到实现使我震惊:GHC堆栈框架在堆上,包括主线程的框架以及正在执行的功能堆分配实际上在最多三个堆栈帧中运行三次嵌套循环。 +RTS -s注册的堆分配只是不断 pop 和推送堆栈帧。

对于以下代码,这在Core上非常明显:
{-# OPTIONS_GHC -fno-full-laziness #-}

-- ...

test_a mvec = 
    forM_ [0..times-1] $ \ n -> 
        forM_ [0..side-1] $ \ y -> 
            forM_ [0..side-1] $ \ x -> 
                MV.unsafeWrite mvec (y*side+x) 1
main = do
    let vec = V.generate (side*side) (const 0)
    mvec <- V.unsafeThaw vec :: IO (MV.MVector (PrimState IO) Int)
    test_a mvec

我将其纳入其中。随时跳过。
main1 :: State# RealWorld -> (# State# RealWorld, () #)
main1 =
  \ (s_a5HK :: State# RealWorld) ->
    case divInt# 9223372036854775807 8 of ww4_a5vr { __DEFAULT ->

    -- start of vector creation ----------------------
    case tagToEnum# (># 10000 ww4_a5vr) of _ {
      False ->
        case newByteArray# 80000 (s_a5HK `cast` ...)
        of _ { (# ipv_a5fv, ipv1_a5fw #) ->
        letrec {
          $s$wa_s8jS
            :: Int#
               -> Int#
               -> State# (PrimState IO)
               -> (# State# (PrimState IO), Int #)
          $s$wa_s8jS =
            \ (sc_s8jO :: Int#)
              (sc1_s8jP :: Int#)
              (sc2_s8jR :: State# (PrimState IO)) ->
              case tagToEnum# (<# sc1_s8jP 10000) of _ {
                False -> (# sc2_s8jR, I# sc_s8jO #);
                True ->
                  case writeIntArray# ipv1_a5fw sc_s8jO 0 (sc2_s8jR `cast` ...)
                  of s'#_a5Gn { __DEFAULT ->
                  $s$wa_s8jS (+# sc_s8jO 1) (+# sc1_s8jP 1) (s'#_a5Gn `cast` ...)
                  }
              }; } in
        case $s$wa_s8jS 0 0 (ipv_a5fv `cast` ...)
        -- end of vector creation -------------------

        of _ { (# ipv6_a4Hv, ipv7_a4Hw #) ->
        letrec {
          a2_s7MJ :: Int# -> State# RealWorld -> (# State# RealWorld, () #)
          a2_s7MJ =
            \ (x_a5Ho :: Int#) (eta_B1 :: State# RealWorld) ->
              letrec {
                a3_s7ME :: Int# -> State# RealWorld -> (# State# RealWorld, () #)
                a3_s7ME =
                  \ (x1_X5Id :: Int#) (eta1_XR :: State# RealWorld) ->
                    case ipv7_a4Hw of _ { I# dt4_a5x6 ->
                    case writeIntArray#
                           (ipv1_a5fw `cast` ...) (*# x1_X5Id 100) 1 (eta1_XR `cast` ...)
                    of s'#_a5Gn { __DEFAULT ->
                    letrec {
                      a4_s7Mz :: Int# -> State# RealWorld -> (# State# RealWorld, () #)
                      a4_s7Mz =
                        \ (x2_X5J8 :: Int#) (eta2_X1U :: State# RealWorld) ->
                          case writeIntArray#
                                 (ipv1_a5fw `cast` ...)
                                 (+# (*# x1_X5Id 100) x2_X5J8)
                                 1
                                 (eta2_X1U `cast` ...)
                          of s'#1_X5Hf { __DEFAULT ->
                          case x2_X5J8 of wild_X2o {
                            __DEFAULT -> a4_s7Mz (+# wild_X2o 1) (s'#1_X5Hf `cast` ...);
                            99 -> (# s'#1_X5Hf `cast` ..., () #)
                          }
                          }; } in
                    case a4_s7Mz 1 (s'#_a5Gn `cast` ...)
                    of _ { (# ipv2_a4QH, ipv3_a4QI #) ->
                    case x1_X5Id of wild_X1e {
                      __DEFAULT -> a3_s7ME (+# wild_X1e 1) ipv2_a4QH;
                      99 -> (# ipv2_a4QH, () #)
                    }
                    }
                    }
                    }; } in
              case a3_s7ME 0 eta_B1 of _ { (# ipv2_a4QH, ipv3_a4QI #) ->
              case x_a5Ho of wild_X1a {
                __DEFAULT -> a2_s7MJ (+# wild_X1a 1) ipv2_a4QH;
                99999 -> (# ipv2_a4QH, () #)
              }
              }; } in
        a2_s7MJ 0 (ipv6_a4Hv `cast` ...)
        }
        };
      True ->
        case error
               (unpackAppendCString#
                  "Primitive.basicUnsafeNew: length to large: "#
                  (case $wshowSignedInt 0 10000 ([])
                   of _ { (# ww5_a5wm, ww6_a5wn #) ->
                   : ww5_a5wm ww6_a5wn
                   }))
        of wild_00 {
        }
    }
    }

main :: IO ()
main = main1 `cast` ...

main2 :: State# RealWorld -> (# State# RealWorld, () #)
main2 = runMainIO1 (main1 `cast` ...)

main :: IO ()
main = main2 `cast` ...

我们还可以通过以下方式很好地演示帧的分配。让我们更改test_a:
test_a mvec = 
    forM_ [0..times-1] $ \ n -> 
        forM_ [0..side-1] $ \ y -> 
            forM_ [0..side-50] $ \ x -> -- change here
                MV.unsafeWrite mvec (y*side+x) 1

现在,堆分配保持完全相同,因为最内部的循环是尾递归的,并且使用单个帧。进行以下更改后,堆分配减少了一半(至124,921,008字节),因为我们推送并 pop 了一半的帧:
test_a mvec = 
    forM_ [0..times-1] $ \ n -> 
        forM_ [0..side-50] $ \ y -> -- change here
            forM_ [0..side-1] $ \ x -> 
                MV.unsafeWrite mvec (y*side+x) 1
test_btest_c(没有完全延迟)改为编译为在单个堆栈框架内使用嵌套case构造的代码,并遍历索引以查看应递增的索引。请参阅以下main的Core:
{-# LANGUAGE BangPatterns #-} -- later I'll talk about this
{-# OPTIONS_GHC -fno-full-laziness #-}

main = do
    let vec = V.generate (side*side) (const 0)
    !mvec <- V.unsafeThaw vec :: IO (MV.MVector (PrimState IO) Int)
    test_c mvec

瞧:
main1 :: State# RealWorld -> (# State# RealWorld, () #)
main1 =
  \ (s_a5Iw :: State# RealWorld) ->
    case divInt# 9223372036854775807 8 of ww4_a5vT { __DEFAULT ->

    -- start of vector creation ----------------------
    case tagToEnum# (># 10000 ww4_a5vT) of _ {
      False ->
        case newByteArray# 80000 (s_a5Iw `cast` ...)
        of _ { (# ipv_a5g3, ipv1_a5g4 #) ->
        letrec {
          $s$wa_s8ji
            :: Int#
               -> Int#
               -> State# (PrimState IO)
               -> (# State# (PrimState IO), Int #)
          $s$wa_s8ji =
            \ (sc_s8je :: Int#)
              (sc1_s8jf :: Int#)
              (sc2_s8jh :: State# (PrimState IO)) ->
              case tagToEnum# (<# sc1_s8jf 10000) of _ {
                False -> (# sc2_s8jh, I# sc_s8je #);
                True ->
                  case writeIntArray# ipv1_a5g4 sc_s8je 0 (sc2_s8jh `cast` ...)
                  of s'#_a5GP { __DEFAULT ->
                  $s$wa_s8ji (+# sc_s8je 1) (+# sc1_s8jf 1) (s'#_a5GP `cast` ...)
                  }
              }; } in
        case $s$wa_s8ji 0 0 (ipv_a5g3 `cast` ...)
        of _ { (# ipv6_a4MX, ipv7_a4MY #) ->
        case ipv7_a4MY of _ { I# dt4_a5xy ->
        -- end of vector creation

        letrec {
          a2_s7Q6 :: Int# -> State# RealWorld -> (# State# RealWorld, () #)
          a2_s7Q6 =
            \ (x_a5HT :: Int#) (eta_B1 :: State# RealWorld) ->
              letrec {
                a3_s7Q5 :: Int# -> State# RealWorld -> (# State# RealWorld, () #)
                a3_s7Q5 =
                  \ (x1_X5J9 :: Int#) (eta1_XP :: State# RealWorld) ->
                    letrec {
                      a4_s7MZ :: Int# -> State# RealWorld -> (# State# RealWorld, () #)
                      a4_s7MZ =
                        \ (x2_X5Jl :: Int#) (s1_X4Xb :: State# RealWorld) ->
                          case writeIntArray#
                                 (ipv1_a5g4 `cast` ...)
                                 (+# (*# x1_X5J9 100) x2_X5Jl)
                                 1
                                 (s1_X4Xb `cast` ...)
                          of s'#_a5GP { __DEFAULT ->

                          -- the interesting part! ------------------
                          case x2_X5Jl of wild_X1y {
                            __DEFAULT -> a4_s7MZ (+# wild_X1y 1) (s'#_a5GP `cast` ...);
                            99 ->
                              case x1_X5J9 of wild1_X1o {
                                __DEFAULT -> a3_s7Q5 (+# wild1_X1o 1) (s'#_a5GP `cast` ...);
                                99 ->
                                  case x_a5HT of wild2_X1c {
                                    __DEFAULT -> a2_s7Q6 (+# wild2_X1c 1) (s'#_a5GP `cast` ...);
                                    99999 -> (# s'#_a5GP `cast` ..., () #)
                                  }
                              }
                          }
                          }; } in
                    a4_s7MZ 0 eta1_XP; } in
              a3_s7Q5 0 eta_B1; } in
        a2_s7Q6 0 (ipv6_a4MX `cast` ...)
        }
        }
        };
      True ->
        case error
               (unpackAppendCString#
                  "Primitive.basicUnsafeNew: length to large: "#
                  (case $wshowSignedInt 0 10000 ([])
                   of _ { (# ww5_a5wO, ww6_a5wP #) ->
                   : ww5_a5wO ww6_a5wP
                   }))
        of wild_00 {
        }
    }
    }

main :: IO ()
main = main1 `cast` ...

main2 :: State# RealWorld -> (# State# RealWorld, () #)
main2 = runMainIO1 (main1 `cast` ...)

main :: IO ()
main = main2 `cast` ...

我不得不承认,我基本上不知道为什么有些代码避免创建堆栈框架,而有些则没有。我怀疑从“内部”向外进行内联会有所帮助,而快速检查告诉我Control.Monad.Loop使用CPS编码,这在这里可能是相关的,尽管Monad.Loop解决方案很容易 float ,而且我无法在短时间内确定从核心来看,为什么带有let float 的test_c无法在单个堆栈帧中运行。

现在,在单个堆栈框架中运行的性能优势很小。我们已经看到test_b仅比test_a快一点。我在回答中包含了这条弯路,因为我发现它很有启发性。

国家黑客和严格的约束

所谓的state hack使GHC积极地内联到IO和ST Action 中。我认为我应该在这里提到它,因为除了 float 之外,这是可能彻底破坏性能的另一件事。

状态hack通过优化-O启用,并且可能会逐渐降低程序速度。 Reid Barton的一个简单示例:
import Control.Monad
import Debug.Trace

expensive :: String -> String
expensive x = trace "$$$" x

main :: IO ()
main = do
  str <- fmap expensive getLine
  replicateM_ 3 $ print str

使用GHC-7.10.2时,它会在没有优化的情况下打印"$$$"一次,但使用-O2则打印三次。似乎在GHC-7.10中,我们无法通过-fno-state-hack(这是来自Reid Barton的链接票证的主题)摆脱这种行为。

严格的单子(monad)绑定(bind)可靠地摆脱了这个问题:
main :: IO ()
main = do
  !str <- fmap expensive getLine
  replicateM_ 3 $ print str

我认为在IO和ST中进行严格绑定(bind)是一个好习惯。而且我有一些经验(虽然不是确定的;我远不是GHC专家),如果我们使用-fno-full-laziness,则特别需要严格的绑定(bind)。显然,完全懒惰可以帮助消除因国家黑客入侵而导致的内联引入的某些工作重复;带有test_b且没有完全延迟的情况下,省略对!mvec <- V.unsafeThaw vec的严格绑定(bind)会导致速度稍有下降,并且Core输出非常丑陋。

关于performance - 在Haskell中执行恒定空间嵌套循环的正确方法是什么?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/32345095/

相关文章:

android - Azure 流式传输仅在 Android 应用程序内缓慢

python - Cython 字符串连接超慢;它还有什么不好的地方?

javascript - 闭包在循环中不能正常工作

javascript - 如何从嵌套结构中获取最高数字

haskell - 使用 GHC (+ LLVM) 将 GMP 静态链接到 Haskell 应用程序

用于简单碰撞检测的 Javascript 位图

python - 插入 mongodb (pymongo) 时的效率

记录 R 中循环的输出

haskell - 如何在 Haskell 中编写模式准引用器?

haskell - 通用量化和统一,一个例子