考虑下面的函数,它将 a * b 的结果转换成一对数字 i 和 j,其中:
- a, b, x, y 是 int(假设它们总是 => 32 位长)
- a 和 b <= n*m,其中 n = 10^3 和 m=10^5。 n*m = 基础。
- a * b 可以写成 i*BASE + j
你将如何计算 j 而不使用任何大于 int 的类型(以防溢出 int 的 UB):
#include <iostream>
#include <cstdlib>
using namespace std;
int n = 1000, m = 100000;
struct N {
int i, j;
};
N f(int a, int b) {
N x;
int a0, a1, b0, b1, o;
a1 = a / n;
a0 = a - (a1 * n); // a0 = a % n
b1 = b / m;
b0 = b - (b1 * m); // b0 = b % m
o = a1 * b1 + (a0 * b1) / n + (b0 * a1) / m;
x.i = o;
x.j = 0; // CALCULATE J WITH INTs MATH
return x;
}
int main(int, char* argv[]) {
int a = atoi(argv[1]),
b = atoi(argv[2]);
N x = f(a, b);
cout << a << " * " << b << " = " << x.i << "*" << n*m
<< " + " << x.j << endl;
cout << "which is: " << (long long)a * b << endl;
return 0;
}
最佳答案
你的开始是正确的,但是失去了关于计算 o
的情节.首先,我的假设:您不想处理任何大于 n*m
的整数。 , 所以取 mod n*m
是作弊。我这么说是因为给出 m > 2^16
,我必须假设 int 是 32 位长,它能够处理您的数字而不会溢出。
无论如何。您正确地(我猜,因为未指定 n
和 m
的目的):
a=a0 + a1*n (a0<n)
b=b0 + b1*m (b0<m)
所以,如果我们算一下:
a*b = a0*b0 + a0*b1*m + a1*b0*n + a1*b1*n*m
在这里,a0*b0 < n*m
, 所以它是 j
的一部分, 和 a1*b1*n*m > n*m
, 所以它是 i
的一部分.您需要将另外两个术语再次拆分为两个。但是你不能计算每个并取 mod n*m
,因为那将是作弊(根据我上面的规则)。如果你写:
a0*b1 = a0b1_0 + a0b1_1*n
你得到:
a0*b1*m = a0b1_0*m + a0b1_1*n*m
自 a0b1_0 < n
, a0b1_0*m < n*m
,这意味着这部分转到 j
.显然,a0b1_1
转到 i.
对 a1*b0 重复类似的逻辑,您就得到了 j
的三个项。 , 还有三个加起来 i
.
编辑:忘了提一些事情:
您需要约束
a < n^2
和b < m^2
为此工作。否则,您需要更多的 ai “单词”。例如:a = a0 + a1*n + a2*n^2, ai < n
.j
的最终总和可能大于n*m
.您需要注意溢出(n*m - o < addend
或类似的逻辑,并在发生这种情况时将1
添加到i
- 同时计算j + addend - n*m
而不会溢出)。
关于c++ - 你会如何计算这个函数中的j?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/6424180/