c++ - 在 C++ 中不使用 cmath 计算 n 次根的有效方法

标签 c++ algorithm math c++20 nth-root

如何在不使用 cmath 等的情况下有效计算数字的 n 次根,精确到至少 12 个正确的小数位?

我已经尝试自己解决这个问题了。我的想法是找到一个近似值,并使用牛顿法使近似值更加准确。

我实现了两种方法,一种使用二分搜索,另一种基于快速逆平方根算法。

#include <array>
#include <chrono>
#include <cmath>
#include <iostream>
#include <vector>

using std::chrono::steady_clock;
using std::chrono::duration;
using std::cout;
using std::vector;
float r = 0.0;

inline float power(float base, int exp) {
    if (not exp) {
        return 1.0;
    }
    if (exp < 0) {
        base = 1 / base;
        exp = -exp;
    }
    float p = 1.0;
    while (exp > 1) {
        if (exp % 2) {
            p = base * p;
        }
        base *= base;
        exp /= 2;
    }
    return base * p;
}

inline float nth_root(float base, int n) {
    float lo, hi, x, p, r, v;
    int n1 = n - 1;
    lo = 0;
    hi = base;
    for (int i = 0; i < 12; i++) {
        x = (lo + hi) / 2;
        p = power(x, n1);
        v = x * p - base;
        r = n * p;
        if (v <= 0) {
            lo = x;
        }
        else {
            hi = x;
        }
    }
    x = (lo + hi) / 2;
    r = 1.0 / n;
    for (int i = 0; i < 12; i++) {
        x = r * (n * x + base / power(x, n));
    }
    return x;
}

inline float fast_nth_root(float base, int n)
{
   uint32_t i = std::bit_cast<uint32_t>(base);
   float rn = 1.0 / n;
   i = 0x3F7A3BEA * rn * (n + 1) - i * rn;
   float x = std::bit_cast<float>(i);
   for (int j = 0; j < 6; j++) {
       x = x * (n + 1 - base * power(x, n)) * rn;
   }
   return 1.0 / x;
}

int main()
{
    vector<float> bases(256);
    vector<int> ns(256);
    float r256 = 1.0 / 256;
    for (int i = 0; i < 256; i++) {
        bases[i] = 1.0 + rand() % 16384 + (rand() % 256) * r256;
        ns[i] = 2 + rand() % 30;
    }
    auto start = steady_clock::now();
    for (int64_t i = 0; i < 1048576; i++) {
        r += nth_root(bases[i % 256], ns[i % 256]);
    }
    auto end = steady_clock::now();
    duration<double, std::nano> time = end - start;
    cout << "nth_root: " << time.count() / 1048576 << " nanoseconds\n";
    start = steady_clock::now();
    for (int64_t i = 0; i < 1048576; i++) {
        r += pow(bases[i % 256], 1.0 / ns[i % 256]);
    }
    end = steady_clock::now();
    time = end - start;
    cout << "pow: " << time.count() / 1048576 << " nanoseconds\n";
    start = steady_clock::now();
    for (int64_t i = 0; i < 1048576; i++) {
        r += fast_nth_root(bases[i % 256], ns[i % 256]);
    }
    end = steady_clock::now();
    time = end - start;
    cout << "fast_nth_root: " << time.count() / 1048576 << " nanoseconds\n";
}

编译:

g++.exe -Wall -fexceptions -fomit-frame-pointer -fexpensive-optimizations -flto -O3 -m64 --std=c++20 -march=native -ffast-math  -c D:\MyScript\CodeBlocks\testapp\main.cpp -o obj\Release\main.o
g++.exe  -o bin\Release\testapp.exe obj\Release\main.o  -O3 -flto -s -static-libstdc++ -static-libgcc -static -m64  
PS C:\Users\Xeni> D:\MyScript\CodeBlocks\testapp\bin\Release\testapp.exe
nth_root: 318.12 nanoseconds
pow: 104.77 nanoseconds
fast_nth_root: 53.4222 nanoseconds

正如预期的那样,第一种方法非常慢,但尽管第二种方法比库代码快,但可能不那么准确。

根据我在 Python 中的测试:

import random, struct

def root(num, p, lim, lin):
    lo = 0
    hi = num
    for _ in range(lim):
        x = (lo + hi) / 2
        po = x ** (p - 1)
        v = x * po - num
        r = p * po
        if v <= 0:
            lo = x
        else:
            hi = x

    x = (lo + hi) / 2
    r = 1 / p
    p -= 1
    for _ in range(lin):
        x = r * (p * x + num / x ** p)
    
    return x

stats = {}
for _ in range(64):
    n = random.randrange(1, 16384)
    p = random.randrange(2, 32)
    x = n ** (1 / p)
    i = 2
    j = 2
    b = 0
    while abs((y := root(n, p, i, j)) - x) > 1e-13:
        if i < 16 and b or i >= 16:
            j += 1
        else:
            i += 1
        b = not b

    stats[(n, p)] = (i, j, x, y)

def fast_nth_root(x: float, n: int, lim: int) -> float:
    t = int.from_bytes(struct.pack('>f', x), 'big')
    t = round(0x3F7A3BEA / n * (n + 1) - t / n)
    t = struct.unpack('>f', struct.pack('>i', t))[0]
    for _ in range(lim):
        t = t * (n + 1 - x * t ** n) / n
    
    return 1 / t


stats1 = {}
for _ in range(64):
    n = random.randrange(1, 16384)
    p = random.randrange(2, 32)
    x = n ** (1 / p)
    i = 2
    while abs((y := fast_nth_root(n, p, i)) - x) > 1e-13:
        i += 1

    stats1[(n, p)] = (i, x, y)
In [1434]: stats
Out[1434]:
{(7162, 13): (11, 10, 1.979434344458145, 1.979434344458145),
 (15510, 2): (5, 5, 124.53915047084591, 124.53915047084595),
 (3054, 25): (10, 9, 1.3784618821753079, 1.3784618821753076),
 (1601, 25): (9, 8, 1.3433081611539914, 1.3433081611540099),
 (3522, 21): (10, 9, 1.475348878994169, 1.475348878994169),
 (15107, 14): (12, 11, 1.9884410975573956, 1.9884410975573954),
 (15200, 16): (12, 11, 1.8254301706001883, 1.825430170600188),
 (1900, 15): (9, 8, 1.6541830766301984, 1.6541830766301981),
 (16145, 20): (12, 11, 1.6233116388185762, 1.623311638818576),
 (2580, 4): (7, 6, 7.126969930959522, 7.126969930959522),
 (1702, 27): (11, 10, 1.3172407839773748, 1.3172407839773748),
 (9875, 29): (13, 13, 1.3732280275029944, 1.3732280275029942),
 (15687, 15): (12, 11, 1.9041565885196923, 1.904156588519692),
 (5774, 16): (12, 11, 1.718273525571221, 1.718273525571221),
 (6186, 2): (5, 4, 78.65112840894274, 78.65112840894277),
 (4476, 23): (12, 11, 1.441233509663115, 1.441233509663115),
 (4161, 24): (12, 11, 1.4151416228042228, 1.4151416228042226),
 (16116, 13): (12, 11, 2.1068575543742956, 2.1068575543742956),
 (12380, 14): (11, 11, 1.9603661231094314, 1.9603661231094311),
 (9736, 19): (13, 12, 1.621491836717726, 1.6214918367177258),
 (8612, 26): (13, 12, 1.4169357394205302, 1.4169357394205302),
 (4586, 7): (9, 8, 3.334740217355978, 3.3347402173559777),
 (5232, 24): (12, 12, 1.428711330576587, 1.4287113305765868),
 (14698, 17): (12, 11, 1.7584613929955697, 1.7584613929955695),
 (4931, 13): (10, 9, 1.923410237452901, 1.9234102374529014),
 (7391, 4): (8, 7, 9.272050761175075, 9.272050761175075),
 (9949, 6): (9, 9, 4.637635073009885, 4.637635073009886),
 (4767, 18): (12, 11, 1.6008364077669808, 1.6008364077669806),
 (16318, 8): (11, 10, 3.3618889684623863, 3.3618889684623863),
 (7610, 28): (13, 12, 1.3760077520394016, 1.3760077520394014),
 (13632, 6): (10, 9, 4.887573066390476, 4.887573066390476),
 (8380, 21): (11, 11, 1.5375213098103222, 1.5375213098103224),
 (7247, 14): (11, 10, 1.88679879448582, 1.8867987944858202),
 (11343, 18): (13, 12, 1.6798196720467486, 1.6798196720467486),
 (6468, 17): (11, 10, 1.6755714114675964, 1.6755714114675964),
 (11801, 6): (10, 9, 4.771480299610415, 4.771480299610415),
 (441, 28): (9, 8, 1.2429230307022932, 1.2429230307022932),
 (15341, 14): (12, 11, 1.9906254301097404, 1.9906254301097404),
 (8501, 20): (11, 10, 1.5720758667453518, 1.5720758667453518),
 (2777, 19): (10, 10, 1.5178918732605664, 1.5178918732605664),
 (14842, 30): (14, 13, 1.3773672345540857, 1.3773672345540857),
 (6149, 28): (11, 10, 1.3655715058830975, 1.3655715058830973),
 (13374, 21): (12, 11, 1.5721306482454025, 1.5721306482454025),
 (9947, 30): (13, 13, 1.3591156205784112, 1.359115620578415),
 (14423, 16): (12, 11, 1.8194535606369682, 1.8194535606369682),
 (9341, 31): (13, 12, 1.3430036888404402, 1.3430036888404402),
 (14558, 7): (10, 9, 3.9330441035217714, 3.9330441035217714),
 (152, 16): (8, 7, 1.3688795144738382, 1.368879514473838),
 (13593, 18): (12, 11, 1.6967920812890351, 1.6967920812890351),
 (2834, 7): (8, 8, 3.113149507653915, 3.1131495076539637),
 (11545, 14): (11, 10, 1.950612466604169, 1.9506124666041689),
 (12416, 21): (12, 11, 1.566576147387506, 1.5665761473875057),
 (8998, 6): (9, 8, 4.560624662646501, 4.560624662646504),
 (5245, 27): (11, 10, 1.3733092699013858, 1.3733092699013856),
 (5693, 29): (11, 10, 1.3473937347050562, 1.3473937347050562),
 (3508, 26): (10, 10, 1.3688266297365765, 1.368826629736599),
 (16237, 9): (11, 10, 2.936526854011741, 2.936526854011741),
 (2911, 6): (8, 7, 3.778682197915392, 3.7786821979153924),
 (387, 22): (10, 9, 1.311061987204142, 1.311061987204142),
 (3324, 4): (7, 7, 7.593032412784651, 7.593032412784651),
 (15300, 22): (12, 11, 1.5495773040868939, 1.5495773040868939),
 (5469, 26): (11, 10, 1.3924053694535468, 1.392405369453547),
 (1195, 13): (8, 8, 1.7247279767538015, 1.724727976753898),
 (7998, 13): (11, 10, 1.9963162339549785, 1.9963162339549787)}

In [1435]: stats1
Out[1435]:
{(10882, 3): (4, 22.159990703206965, 22.159990703206965),
 (6673, 28): (6, 1.3695657835909767, 1.3695657835909767),
 (4803, 10): (4, 2.3342709144708604, 2.3342709144708604),
 (1802, 27): (5, 1.3200291160996098, 1.3200291160996294),
 (8380, 15): (4, 1.8262053100463662, 1.826205310046366),
 (12898, 21): (5, 1.569419919895426, 1.5694199198954262),
 (10227, 2): (4, 101.12863096077193, 101.12863096077193),
 (4857, 25): (5, 1.4042832772764529, 1.4042832772765208),
 (1351, 12): (4, 1.8234251715932501, 1.8234251715932501),
 (10180, 16): (3, 1.7802632882832108, 1.7802632882832108),
 (6948, 28): (6, 1.37154252932099, 1.3715425293209902),
 (13901, 10): (4, 2.5959994991170756, 2.5959994991171),
 (7513, 21): (5, 1.529545998990047, 1.529545998990047),
 (7902, 18): (4, 1.6464211857566509, 1.6464211857566777),
 (3277, 31): (6, 1.2983819395882321, 1.2983819395882321),
 (3499, 10): (4, 2.261488555258189, 2.2614885552581887),
 (15234, 30): (6, 1.3785646304878574, 1.3785646304878574),
 (5739, 13): (5, 1.9459928850146935, 1.9459928850146935),
 (1823, 24): (5, 1.3673072329614473, 1.367307232961453),
 (15105, 16): (4, 1.8247150144295399, 1.8247150144295399),
 (16215, 12): (3, 2.2429852247131876, 2.2429852247131876),
 (15844, 20): (5, 1.6217848596088165, 1.6217848596088162),
 (15677, 26): (5, 1.4499608216310222, 1.4499608216310922),
 (11839, 22): (5, 1.5316187797414815, 1.5316187797414818),
 (10163, 26): (6, 1.4259891723870766, 1.4259891723870766),
 (1550, 18): (5, 1.5039751132184096, 1.5039751132184096),
 (15194, 5): (4, 6.8601628355219795, 6.860162835521979),
 (15612, 24): (5, 1.4952969217858556, 1.495296921785857),
 (9469, 12): (4, 2.1446611088057317, 2.144661108805732),
 (4030, 20): (5, 1.5144859625611744, 1.5144859625611742),
 (11729, 3): (4, 22.720627875592783, 22.72062787559279),
 (12709, 29): (6, 1.3852274277113097, 1.3852274277113097),
 (12263, 31): (6, 1.354846884624788, 1.354846884624788),
 (6372, 9): (4, 2.6466548424778766, 2.6466548424779086),
 (7119, 5): (4, 5.8949998528103436, 5.8949998528103436),
 (10737, 27): (6, 1.4102365332846034, 1.4102365332846034),
 (2231, 11): (4, 2.01562182326949, 2.0156218232694902),
 (412, 9): (4, 1.952289125066342, 1.9522891250663446),
 (8417, 5): (4, 6.095810609109851, 6.095810609109851),
 (6759, 31): (6, 1.3290600231725829, 1.3290600231725826),
 (2207, 23): (5, 1.3975994148304356, 1.3975994148304371),
 (4755, 16): (4, 1.6975473914419268, 1.697547391441927),
 (7978, 13): (4, 1.995931787083301, 1.9959317870833602),
 (14957, 19): (4, 1.658550339552885, 1.658550339552935),
 (745, 28): (5, 1.2664178106847905, 1.2664178106847914),
 (2696, 15): (4, 1.6932249497186587, 1.6932249497186587),
 (5484, 7): (4, 3.421029175738748, 3.421029175738748),
 (15410, 10): (4, 2.6228911266357167, 2.6228911266357264),
 (315, 10): (3, 1.7775877772276876, 1.7775877772276876),
 (9252, 13): (4, 2.0188081438649, 2.0188081438649035),
 (1562, 27): (5, 1.313059731029047, 1.3130597310290748),
 (9803, 8): (4, 3.1544225980493534, 3.154422598049354),
 (14443, 16): (4, 1.8196111450479415, 1.8196111450479413),
 (5033, 23): (5, 1.4486017214956801, 1.4486017214956872),
 (16175, 2): (4, 127.18097341976905, 127.18097341976903),
 (15125, 5): (4, 6.853920721113647, 6.853920721113645),
 (16292, 19): (4, 1.6660301757087943, 1.666030175708797),
 (11486, 20): (5, 1.5959101637580406, 1.5959101637580406),
 (13824, 7): (4, 3.9040835527337374, 3.9040835527337383),
 (8604, 3): (4, 20.491172086350442, 20.491172086350442),
 (1225, 3): (4, 10.699874805650794, 10.699874805650795),
 (9163, 21): (5, 1.54407524840236, 1.5440752484023603),
 (7833, 21): (5, 1.53258704058841, 1.53258704058841),
 (8425, 24): (5, 1.4573551932369144, 1.4573551932369178)}

平均而言,第一种方法需要大约 12 次二分搜索迭代和 12 次牛顿法迭代才能使误差低于 10-13,而第二种方法则需要 6 次牛顿法迭代方法获得相同的精度。

有没有办法让代码在相同迭代次数下运行得更快,或者有办法加快所涉及数学的收敛速度?


这不是作业。这是一个 self 施加的编程挑战。

最佳答案

您没有说明您的函数应覆盖哪些值范围,也没有说明这些值的统计分布,因此我们必须建议一种通用方法。

该策略是首先通过研究数量级来找到牛顿迭代的良好初始近似值。

如果您可以访问浮点表示形式的指数(通过破解二进制表示形式),一个好的起始值是 √2/2 乘以 2^(exponent/n)。我们选择 √2/2 而不是 1 来以 [1/2, 1) 为中心。为了提高效率,您可以为所有指数和所有根阶预先计算这些常数。不管怎样,为了节省空间,最好将 n 上的指数分解为整数商和余数。

如果指数不可用,那么您可以从 1 开始,通过连续加倍(或减半)来搜​​索它(因此 1=2^0、2=2^1、4=2^2、8=2 ^3...)。这相当于指数之间的线性搜索。然而,更有效的方法是使用平方,并在指数之间实现指数搜索(每次将幂加倍,2=2^1、4=2^2、16=2^4、256=2^8。 ..)。找到可能的指数范围后,恢复线性搜索。您还可以使用预先计算的值进行优化。

最后,您可以从牛顿迭代开始。对于平方根的情况,您可以使用平方根的倒数来避免除法。不幸的是,这并不能推广到高订单。

最后但并非最不重要的一点是,确定最坏情况下(具有最差初始近似值)所需的迭代次数可能是有益的,并且始终使用该数字,而不是测试具有一定容差的收敛性。

关于c++ - 在 C++ 中不使用 cmath 计算 n 次根的有效方法,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/77179206/

相关文章:

algorithm - 解决这种逻辑问题的方法是什么?

algorithm - 多短规则模式匹配算法

android - Cmake 测试被忽略 - 在 android 中使用 gradle 构建 cmake 项目

c++ - Boost::Asio 多播监听地址

c++ - 从包含层次结构生成单个包含文件

algorithm - 如何用不相交的旋转矩形填充矩形区域?

C++ 命名空间未声明

java - 动态设置索引。这样所有可能的输入都可以被第一个整除

javascript - 模运算是否在不同的基上起作用?

math - 从平面上的 2D 点转换为 3D