如何在不使用 cmath 等的情况下有效计算数字的 n 次根,精确到至少 12 个正确的小数位?
我已经尝试自己解决这个问题了。我的想法是找到一个近似值,并使用牛顿法使近似值更加准确。
我实现了两种方法,一种使用二分搜索,另一种基于快速逆平方根算法。
#include <array>
#include <bit>
#include <chrono>
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <vector>
using std::chrono::steady_clock;
using std::chrono::duration;
using std::cout;
using std::vector;
// Global accumulator: every benchmark loop in main() adds its results into
// this variable, so each timed call has a result that is written somewhere.
float r = 0.0;
/// Raises `base` to the integer power `exp` using square-and-multiply.
///
/// @param base Value to raise.
/// @param exp  Integer exponent; a negative exponent inverts `base` first.
/// @return `base` raised to `exp`; 1 when `exp` is 0.
inline float power(float base, int exp) {
    if (exp == 0) {
        return 1.0;
    }
    if (exp < 0) {
        // x^(-k) == (1/x)^k, so continue with a positive exponent.
        base = 1 / base;
        exp = -exp;
    }
    float acc = 1.0;
    // Consume the exponent bit by bit from the low end; the top bit is
    // applied by the final multiplication after the loop.
    for (; exp > 1; exp /= 2) {
        if (exp & 1) {
            acc = base * acc;
        }
        base *= base;
    }
    return base * acc;
}
/// Computes the n-th root of `base` with 12 bisection steps followed by
/// 12 Newton iterations.
///
/// @param base Radicand. NOTE(review): the bracketing below assumes
///             base >= 0; negative inputs are not handled.
/// @param n    Root order. NOTE(review): assumed to be >= 1 (n == 0 would
///             divide by zero).
/// @return Approximation of base^(1/n) in single precision.
inline float nth_root(float base, int n) {
    if (base == 0) {
        return 0;  // avoids 0/0 in the Newton step below
    }
    const int n1 = n - 1;

    // Bisection on f(x) = x^n - base. For 0 < base < 1 the root is larger
    // than base, so the upper bound must be at least 1 to bracket it.
    float lo = 0;
    float hi = base < 1 ? 1.0f : base;
    for (int i = 0; i < 12; i++) {
        const float mid = (lo + hi) / 2;
        if (mid * power(mid, n1) - base <= 0) {
            lo = mid;
        }
        else {
            hi = mid;
        }
    }

    // Newton refinement: x <- ((n - 1) * x + base / x^(n - 1)) / n.
    // Both occurrences must use n - 1; using n here makes x drift upward
    // on every step instead of converging to the root.
    float x = (lo + hi) / 2;
    const float inv_n = 1.0 / n;
    for (int i = 0; i < 12; i++) {
        x = inv_n * (n1 * x + base / power(x, n1));
    }
    return x;
}
/// Approximates the n-th root of `base` by estimating the reciprocal root
/// from the float's bit pattern, refining it with six Newton steps, and
/// inverting the result.
///
/// @param base Radicand. NOTE(review): the bit-pattern seed presumably
///             requires a positive, finite float - confirm with callers.
/// @param n    Root order (used as a divisor, so it must be non-zero).
/// @return Approximation of base^(1/n) in single precision.
inline float fast_nth_root(float base, int n)
{
    const float inv_n = 1.0 / n;
    const uint32_t base_bits = std::bit_cast<uint32_t>(base);
    // Initial guess for base^(-1/n), built directly in the integer domain
    // from a magic constant and the raw bits of `base` (same idea as the
    // fast inverse square root trick).
    const uint32_t seed_bits = 0x3F7A3BEA * inv_n * (n + 1) - base_bits * inv_n;
    float y = std::bit_cast<float>(seed_bits);
    // Division-free Newton step for y = base^(-1/n):
    //   y <- y * (n + 1 - base * y^n) / n
    int remaining = 6;
    while (remaining-- > 0) {
        y = y * (n + 1 - base * power(y, n)) * inv_n;
    }
    return 1.0 / y;
}
int main()
{
vector<float> bases(256);
vector<int> ns(256);
float r256 = 1.0 / 256;
for (int i = 0; i < 256; i++) {
bases[i] = 1.0 + rand() % 16384 + (rand() % 256) * r256;
ns[i] = 2 + rand() % 30;
}
auto start = steady_clock::now();
for (int64_t i = 0; i < 1048576; i++) {
r += nth_root(bases[i % 256], ns[i % 256]);
}
auto end = steady_clock::now();
duration<double, std::nano> time = end - start;
cout << "nth_root: " << time.count() / 1048576 << " nanoseconds\n";
start = steady_clock::now();
for (int64_t i = 0; i < 1048576; i++) {
r += pow(bases[i % 256], 1.0 / ns[i % 256]);
}
end = steady_clock::now();
time = end - start;
cout << "pow: " << time.count() / 1048576 << " nanoseconds\n";
start = steady_clock::now();
for (int64_t i = 0; i < 1048576; i++) {
r += fast_nth_root(bases[i % 256], ns[i % 256]);
}
end = steady_clock::now();
time = end - start;
cout << "fast_nth_root: " << time.count() / 1048576 << " nanoseconds\n";
}
编译:
g++.exe -Wall -fexceptions -fomit-frame-pointer -fexpensive-optimizations -flto -O3 -m64 --std=c++20 -march=native -ffast-math -c D:\MyScript\CodeBlocks\testapp\main.cpp -o obj\Release\main.o
g++.exe -o bin\Release\testapp.exe obj\Release\main.o -O3 -flto -s -static-libstdc++ -static-libgcc -static -m64
PS C:\Users\Xeni> D:\MyScript\CodeBlocks\testapp\bin\Release\testapp.exe
nth_root: 318.12 nanoseconds
pow: 104.77 nanoseconds
fast_nth_root: 53.4222 nanoseconds
正如预期,第一种方法非常慢;第二种方法虽然比库函数快,但可能不那么准确。
根据我在 Python 中的测试:
import random, struct
def root(num, p, lim, lin):
    """Approximate the p-th root of ``num``.

    Runs ``lim`` bisection steps to get a starting value, then ``lin``
    Newton iterations to refine it.

    num: radicand (assumed non-negative -- negative values are not handled).
    p:   root order (used as a divisor, so it must be non-zero).
    lim: number of bisection steps.
    lin: number of Newton iterations.
    Returns the approximation as a float.
    """
    if num == 0:
        # The Newton step below would evaluate 0 / 0.
        return 0.0
    lo = 0
    # For 0 < num < 1 the root is larger than num, so the upper bound must
    # be at least 1 for [lo, hi] to bracket it.
    hi = max(num, 1)
    for _ in range(lim):
        x = (lo + hi) / 2
        if x * x ** (p - 1) - num <= 0:
            lo = x
        else:
            hi = x
    x = (lo + hi) / 2
    inv_p = 1 / p
    q = p - 1
    # Newton step for x**p = num:  x <- ((p - 1) * x + num / x**(p - 1)) / p
    for _ in range(lin):
        x = inv_p * (q * x + num / x ** q)
    return x
# For 64 random (n, p) pairs, find how many bisection steps (i) and Newton
# iterations (j) root() needs before it agrees with n ** (1 / p) to within
# 1e-13. The two counts are raised alternately until i reaches 16; after
# that only j grows.
stats = {}
for _ in range(64):
    n = random.randrange(1, 16384)
    p = random.randrange(2, 32)
    x = n ** (1 / p)
    i = j = 2
    b = 0
    while True:
        y = root(n, p, i, j)
        if not abs(y - x) > 1e-13:
            break
        if b or i >= 16:
            j += 1
        else:
            i += 1
        b = not b
    stats[(n, p)] = (i, j, x, y)
def fast_nth_root(x: float, n: int, lim: int) -> float:
    """Approximate the n-th root of ``x`` via its reciprocal root.

    The single-precision bit pattern of ``x`` is turned into an initial
    guess for ``x ** (-1 / n)``, which is refined by ``lim`` Newton steps
    and finally inverted.

    x:   radicand (packed as a big-endian 32-bit float).
    n:   root order.
    lim: number of Newton iterations.
    """
    bits = int.from_bytes(struct.pack('>f', x), 'big')
    # Integer-domain estimate of the bit pattern of x ** (-1 / n).
    seed = round(0x3F7A3BEA / n * (n + 1) - bits / n)
    (y,) = struct.unpack('>f', struct.pack('>i', seed))
    # Division-free Newton step:  y <- y * (n + 1 - x * y**n) / n
    step = 0
    while step < lim:
        y = y * (n + 1 - x * y ** n) / n
        step += 1
    return 1 / y
# For 64 random (n, p) pairs, find the smallest Newton iteration count i
# (starting from 2) at which fast_nth_root() agrees with n ** (1 / p) to
# within 1e-13.
stats1 = {}
for _ in range(64):
    n = random.randrange(1, 16384)
    p = random.randrange(2, 32)
    x = n ** (1 / p)
    i = 2
    while True:
        y = fast_nth_root(n, p, i)
        if not abs(y - x) > 1e-13:
            break
        i += 1
    stats1[(n, p)] = (i, x, y)
In [1434]: stats
Out[1434]:
{(7162, 13): (11, 10, 1.979434344458145, 1.979434344458145),
(15510, 2): (5, 5, 124.53915047084591, 124.53915047084595),
(3054, 25): (10, 9, 1.3784618821753079, 1.3784618821753076),
(1601, 25): (9, 8, 1.3433081611539914, 1.3433081611540099),
(3522, 21): (10, 9, 1.475348878994169, 1.475348878994169),
(15107, 14): (12, 11, 1.9884410975573956, 1.9884410975573954),
(15200, 16): (12, 11, 1.8254301706001883, 1.825430170600188),
(1900, 15): (9, 8, 1.6541830766301984, 1.6541830766301981),
(16145, 20): (12, 11, 1.6233116388185762, 1.623311638818576),
(2580, 4): (7, 6, 7.126969930959522, 7.126969930959522),
(1702, 27): (11, 10, 1.3172407839773748, 1.3172407839773748),
(9875, 29): (13, 13, 1.3732280275029944, 1.3732280275029942),
(15687, 15): (12, 11, 1.9041565885196923, 1.904156588519692),
(5774, 16): (12, 11, 1.718273525571221, 1.718273525571221),
(6186, 2): (5, 4, 78.65112840894274, 78.65112840894277),
(4476, 23): (12, 11, 1.441233509663115, 1.441233509663115),
(4161, 24): (12, 11, 1.4151416228042228, 1.4151416228042226),
(16116, 13): (12, 11, 2.1068575543742956, 2.1068575543742956),
(12380, 14): (11, 11, 1.9603661231094314, 1.9603661231094311),
(9736, 19): (13, 12, 1.621491836717726, 1.6214918367177258),
(8612, 26): (13, 12, 1.4169357394205302, 1.4169357394205302),
(4586, 7): (9, 8, 3.334740217355978, 3.3347402173559777),
(5232, 24): (12, 12, 1.428711330576587, 1.4287113305765868),
(14698, 17): (12, 11, 1.7584613929955697, 1.7584613929955695),
(4931, 13): (10, 9, 1.923410237452901, 1.9234102374529014),
(7391, 4): (8, 7, 9.272050761175075, 9.272050761175075),
(9949, 6): (9, 9, 4.637635073009885, 4.637635073009886),
(4767, 18): (12, 11, 1.6008364077669808, 1.6008364077669806),
(16318, 8): (11, 10, 3.3618889684623863, 3.3618889684623863),
(7610, 28): (13, 12, 1.3760077520394016, 1.3760077520394014),
(13632, 6): (10, 9, 4.887573066390476, 4.887573066390476),
(8380, 21): (11, 11, 1.5375213098103222, 1.5375213098103224),
(7247, 14): (11, 10, 1.88679879448582, 1.8867987944858202),
(11343, 18): (13, 12, 1.6798196720467486, 1.6798196720467486),
(6468, 17): (11, 10, 1.6755714114675964, 1.6755714114675964),
(11801, 6): (10, 9, 4.771480299610415, 4.771480299610415),
(441, 28): (9, 8, 1.2429230307022932, 1.2429230307022932),
(15341, 14): (12, 11, 1.9906254301097404, 1.9906254301097404),
(8501, 20): (11, 10, 1.5720758667453518, 1.5720758667453518),
(2777, 19): (10, 10, 1.5178918732605664, 1.5178918732605664),
(14842, 30): (14, 13, 1.3773672345540857, 1.3773672345540857),
(6149, 28): (11, 10, 1.3655715058830975, 1.3655715058830973),
(13374, 21): (12, 11, 1.5721306482454025, 1.5721306482454025),
(9947, 30): (13, 13, 1.3591156205784112, 1.359115620578415),
(14423, 16): (12, 11, 1.8194535606369682, 1.8194535606369682),
(9341, 31): (13, 12, 1.3430036888404402, 1.3430036888404402),
(14558, 7): (10, 9, 3.9330441035217714, 3.9330441035217714),
(152, 16): (8, 7, 1.3688795144738382, 1.368879514473838),
(13593, 18): (12, 11, 1.6967920812890351, 1.6967920812890351),
(2834, 7): (8, 8, 3.113149507653915, 3.1131495076539637),
(11545, 14): (11, 10, 1.950612466604169, 1.9506124666041689),
(12416, 21): (12, 11, 1.566576147387506, 1.5665761473875057),
(8998, 6): (9, 8, 4.560624662646501, 4.560624662646504),
(5245, 27): (11, 10, 1.3733092699013858, 1.3733092699013856),
(5693, 29): (11, 10, 1.3473937347050562, 1.3473937347050562),
(3508, 26): (10, 10, 1.3688266297365765, 1.368826629736599),
(16237, 9): (11, 10, 2.936526854011741, 2.936526854011741),
(2911, 6): (8, 7, 3.778682197915392, 3.7786821979153924),
(387, 22): (10, 9, 1.311061987204142, 1.311061987204142),
(3324, 4): (7, 7, 7.593032412784651, 7.593032412784651),
(15300, 22): (12, 11, 1.5495773040868939, 1.5495773040868939),
(5469, 26): (11, 10, 1.3924053694535468, 1.392405369453547),
(1195, 13): (8, 8, 1.7247279767538015, 1.724727976753898),
(7998, 13): (11, 10, 1.9963162339549785, 1.9963162339549787)}
In [1435]: stats1
Out[1435]:
{(10882, 3): (4, 22.159990703206965, 22.159990703206965),
(6673, 28): (6, 1.3695657835909767, 1.3695657835909767),
(4803, 10): (4, 2.3342709144708604, 2.3342709144708604),
(1802, 27): (5, 1.3200291160996098, 1.3200291160996294),
(8380, 15): (4, 1.8262053100463662, 1.826205310046366),
(12898, 21): (5, 1.569419919895426, 1.5694199198954262),
(10227, 2): (4, 101.12863096077193, 101.12863096077193),
(4857, 25): (5, 1.4042832772764529, 1.4042832772765208),
(1351, 12): (4, 1.8234251715932501, 1.8234251715932501),
(10180, 16): (3, 1.7802632882832108, 1.7802632882832108),
(6948, 28): (6, 1.37154252932099, 1.3715425293209902),
(13901, 10): (4, 2.5959994991170756, 2.5959994991171),
(7513, 21): (5, 1.529545998990047, 1.529545998990047),
(7902, 18): (4, 1.6464211857566509, 1.6464211857566777),
(3277, 31): (6, 1.2983819395882321, 1.2983819395882321),
(3499, 10): (4, 2.261488555258189, 2.2614885552581887),
(15234, 30): (6, 1.3785646304878574, 1.3785646304878574),
(5739, 13): (5, 1.9459928850146935, 1.9459928850146935),
(1823, 24): (5, 1.3673072329614473, 1.367307232961453),
(15105, 16): (4, 1.8247150144295399, 1.8247150144295399),
(16215, 12): (3, 2.2429852247131876, 2.2429852247131876),
(15844, 20): (5, 1.6217848596088165, 1.6217848596088162),
(15677, 26): (5, 1.4499608216310222, 1.4499608216310922),
(11839, 22): (5, 1.5316187797414815, 1.5316187797414818),
(10163, 26): (6, 1.4259891723870766, 1.4259891723870766),
(1550, 18): (5, 1.5039751132184096, 1.5039751132184096),
(15194, 5): (4, 6.8601628355219795, 6.860162835521979),
(15612, 24): (5, 1.4952969217858556, 1.495296921785857),
(9469, 12): (4, 2.1446611088057317, 2.144661108805732),
(4030, 20): (5, 1.5144859625611744, 1.5144859625611742),
(11729, 3): (4, 22.720627875592783, 22.72062787559279),
(12709, 29): (6, 1.3852274277113097, 1.3852274277113097),
(12263, 31): (6, 1.354846884624788, 1.354846884624788),
(6372, 9): (4, 2.6466548424778766, 2.6466548424779086),
(7119, 5): (4, 5.8949998528103436, 5.8949998528103436),
(10737, 27): (6, 1.4102365332846034, 1.4102365332846034),
(2231, 11): (4, 2.01562182326949, 2.0156218232694902),
(412, 9): (4, 1.952289125066342, 1.9522891250663446),
(8417, 5): (4, 6.095810609109851, 6.095810609109851),
(6759, 31): (6, 1.3290600231725829, 1.3290600231725826),
(2207, 23): (5, 1.3975994148304356, 1.3975994148304371),
(4755, 16): (4, 1.6975473914419268, 1.697547391441927),
(7978, 13): (4, 1.995931787083301, 1.9959317870833602),
(14957, 19): (4, 1.658550339552885, 1.658550339552935),
(745, 28): (5, 1.2664178106847905, 1.2664178106847914),
(2696, 15): (4, 1.6932249497186587, 1.6932249497186587),
(5484, 7): (4, 3.421029175738748, 3.421029175738748),
(15410, 10): (4, 2.6228911266357167, 2.6228911266357264),
(315, 10): (3, 1.7775877772276876, 1.7775877772276876),
(9252, 13): (4, 2.0188081438649, 2.0188081438649035),
(1562, 27): (5, 1.313059731029047, 1.3130597310290748),
(9803, 8): (4, 3.1544225980493534, 3.154422598049354),
(14443, 16): (4, 1.8196111450479415, 1.8196111450479413),
(5033, 23): (5, 1.4486017214956801, 1.4486017214956872),
(16175, 2): (4, 127.18097341976905, 127.18097341976903),
(15125, 5): (4, 6.853920721113647, 6.853920721113645),
(16292, 19): (4, 1.6660301757087943, 1.666030175708797),
(11486, 20): (5, 1.5959101637580406, 1.5959101637580406),
(13824, 7): (4, 3.9040835527337374, 3.9040835527337383),
(8604, 3): (4, 20.491172086350442, 20.491172086350442),
(1225, 3): (4, 10.699874805650794, 10.699874805650795),
(9163, 21): (5, 1.54407524840236, 1.5440752484023603),
(7833, 21): (5, 1.53258704058841, 1.53258704058841),
(8425, 24): (5, 1.4573551932369144, 1.4573551932369178)}
平均而言,第一种方法需要大约 12 次二分搜索迭代和 12 次牛顿法迭代才能使误差低于 $10^{-13}$,而第二种方法只需要 6 次牛顿法迭代即可获得相同的精度。
有没有办法让代码在相同迭代次数下运行得更快,或者有办法加快所涉及数学的收敛速度?
这不是作业。这是我给自己设定的编程挑战。
最佳答案
您没有说明您的函数应覆盖哪些值范围,也没有说明这些值的统计分布,因此我们必须建议一种通用方法。
该策略是首先通过研究数量级来找到牛顿迭代的良好初始近似值。
如果您可以访问浮点表示形式的指数(通过直接操作其二进制表示),一个好的起始值是 √2/2 乘以 2^(exponent/n)。我们选择 √2/2 而不是 1,是为了使其位于区间 [1/2, 1) 的中心。为了提高效率,您可以为所有指数和所有根阶预先计算这些常数。不过,为了节省空间,最好将指数除以 n,分解为整数商和余数。
如果无法获得指数,那么您可以从 1 开始,通过连续加倍(或减半)来搜索它(即 1=2^0、2=2^1、4=2^2、8=2^3……)。这相当于在指数之间进行线性搜索。然而,更有效的方法是使用平方,在指数之间实现指数搜索(每次将幂加倍:2=2^1、4=2^2、16=2^4、256=2^8……)。找到可能的指数范围后,再恢复线性搜索。您还可以使用预先计算的值进行优化。
最后,就可以开始牛顿迭代了。对于平方根的情况,您可以改为计算平方根的倒数来避免除法。不幸的是,这并不能推广到更高阶的根。
最后但并非最不重要的一点是,确定最坏情况下(具有最差初始近似值)所需的迭代次数可能是有益的,并且始终使用该数字,而不是测试具有一定容差的收敛性。
关于c++ - 在 C++ 中不使用 cmath 计算 n 次根的有效方法,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/77179206/