c - 如何在 C 中实现模幂运算,使其所需的字节数最多为参与运算的数字的两倍?

标签 c algorithm overflow exponentiation modular-arithmetic

也就是说,是否存在这样一个计算 m^e mod n 的算法:

// powmod(m,e,n) = m^e % n
unsigned long long powmod(unsigned long m, unsigned long e, unsigned long long n)

它在 m = 2^32 - 1、e = 3、n = 2^64 - 1 时也不会溢出,并且不使用 GMP 或类似的库?

最佳答案

是的,可以做到。下面的代码用 Go 编写,因为 Go 的 math/big 内置了 Exp,可以直接用来对照测试;由于只有测试部分用到了库,把它翻译成 C 应该非常简单。

package main

import (
    "fmt"
    "math"
    "math/big"
    "math/rand"
)

// AddMod returns (a + b) mod m without ever forming a sum that could
// overflow 64 bits. Both operands are reduced modulo m first, so the
// result lies in [0, m). A zero modulus causes a division-by-zero panic.
func AddMod(a, b, m uint64) uint64 {
	x, y := a%m, b%m
	// gap is the distance from y up to m; once x reaches it, x+y would
	// pass m, so subtract gap instead of adding y.
	gap := m - y
	if x < gap {
		return x + y
	}
	return x - gap
}

// SubMod returns (a - b) mod m in the range [0, m) without unsigned
// underflow. Both operands are reduced modulo m first. A zero modulus
// causes a division-by-zero panic.
func SubMod(a, b, m uint64) uint64 {
	x, y := a%m, b%m
	if x >= y {
		return x - y
	}
	// x < y: borrow one copy of m. m-y is positive and x+(m-y) < m.
	return x + (m - y)
}

// Lshift32Mod returns 2^32 a (mod m).
//
// The product 2^32 a can need 96 bits, so it is never formed. Let A = 2^32 a
// and let q* = floor(A/m) be the exact quotient. The quotient is
// under-approximated by q = floor(A/(m+err)), where err in (0, 2^32] is chosen
// so that 2^32 divides m+err, i.e. m+err = 2^32 (mHi+1). The error obeys
//
//	q* - q < A (1/m - 1/(m+err)) + 1 = A err/(m (m+err)) + 1,
//
// hence A - q m < m + A err/(m+err) + m < 2 m + 2^64. Only a handful of
// correction steps are needed before the remainder fits back into 64 bits.
// A zero modulus causes a division-by-zero panic.
func Lshift32Mod(a, m uint64) uint64 {
	a %= m
	mHi, mLo := m>>32, m&math.MaxUint32

	// q = floor(A / (2^32 (mHi+1))) = floor(a / (mHi+1)).
	q := a / (mHi + 1)
	qHi, qLo := q>>32, q&math.MaxUint32

	// Express A - q m as hi*2^32 - lo, where hi = a - floor(q m / 2^32) and lo
	// is the low 32 bits of q m (only qLo*mLo contributes to them).
	cross := qLo * mLo
	lo := cross & math.MaxUint32
	hi := a - (cross>>32 + qLo*mHi + qHi*mLo + (qHi*mHi)<<32)

	// Remove further copies of m (mHi from the high part, mLo added to the
	// borrowed low part) until hi*2^32 fits in a uint64.
	for hi > math.MaxUint32 {
		hi -= mHi
		lo += mLo
	}
	return SubMod(hi<<32, lo, m)
}

// MulMod returns a b (mod m) using only 64-bit arithmetic.
//
// Writing a = aHi 2^32 + aLo and b = bHi 2^32 + bLo, the full product is
//
//	a b = 2^32 (2^32 aHi bHi + (aLo bHi + aHi bLo)) + aLo bLo,
//
// and every 32x32-bit partial product fits in a uint64. The expression is
// evaluated from the inside out, reducing modulo m after each shift by 2^32.
// A zero modulus causes a division-by-zero panic.
func MulMod(a, b, m uint64) uint64 {
	aHi, aLo := a>>32, a&math.MaxUint32
	bHi, bLo := b>>32, b&math.MaxUint32

	// Cross terms, already reduced modulo m.
	cross := AddMod(aLo*bHi, aHi*bLo, m)
	// 2^32 aHi bHi (mod m).
	top := Lshift32Mod(aHi*bHi, m)
	// 2^32 (2^32 aHi bHi + cross terms) (mod m).
	upper := Lshift32Mod(AddMod(cross, top, m), m)
	return AddMod(aLo*bLo, upper, m)
}

// PowMod returns a^b (mod m), where 0^0 = 1.
//
// It uses right-to-left binary exponentiation. base walks through
// a^(2^k) mod m, and it is multiplied into the result for every set bit k of
// b. Starting from 1 % m makes the result 0 when m == 1. A zero modulus causes
// a division-by-zero panic.
func PowMod(a, b, m uint64) uint64 {
	result := 1 % m
	base := a
	for e := b; e > 0; e >>= 1 {
		if e&1 == 1 {
			result = MulMod(result, base, m)
		}
		base = MulMod(base, base, m)
	}
	return result
}

// randUint64 returns a pseudo-random 64-bit value from the package-level
// math/rand source. rand.Uint64 draws all 64 bits at once, so there is no need
// to stitch two 32-bit draws together by hand.
func randUint64() uint64 {
	return rand.Uint64()
}

// main runs a randomized stress test that never terminates on its own.
// Every iteration compares MulMod with math/big, and every tenth iteration
// also compares PowMod with big.Int.Exp. The first mismatch panics with the
// offending inputs. The iteration count is printed every 100000 iterations
// as a progress marker.
func main() {
	var biga, bigb, bigm, actual, bigmul, expected big.Int
	for i := 1; ; i++ {
		a := randUint64()
		b := randUint64()
		m := randUint64()
		// A uniform 64-bit modulus is almost always >= 2^60, so the
		// small-modulus branches of Lshift32Mod (m < 2^32, or only a few high
		// bits set) would practically never run. Half the time, give m a
		// random bit length instead.
		if rand.Intn(2) == 0 {
			m >>= uint(rand.Intn(64))
		}
		// Reduction modulo zero is undefined: both MulMod and big.Int.Mod
		// panic on it, so skip that draw.
		if m == 0 {
			continue
		}
		biga.SetUint64(a)
		bigb.SetUint64(b)
		bigm.SetUint64(m)
		actual.SetUint64(MulMod(a, b, m))
		bigmul.Mul(&biga, &bigb)
		expected.Mod(&bigmul, &bigm)
		if actual.Cmp(&expected) != 0 {
			panic(fmt.Sprintf("MulMod(%d, %d, %d): expected %s; got %s", a, b, m, expected.String(), actual.String()))
		}
		if i%10 == 0 {
			actual.SetUint64(PowMod(a, b, m))
			expected.Exp(&biga, &bigb, &bigm)
			if actual.Cmp(&expected) != 0 {
				panic(fmt.Sprintf("PowMod(%d, %d, %d): expected %s; got %s", a, b, m, expected.String(), actual.String()))
			}
		}
		if i%100000 == 0 {
			// Use fmt rather than the println builtin. The builtin writes to
			// stderr and is documented as not guaranteed to stay in the language.
			fmt.Println(i)
		}
	}
}

以下是上述代码的 C 语言翻译,main 函数中使用了一组边界测试值:

#include <stdio.h>
// AddMod returns (a + b) mod m without forming a sum that could wrap around.
// Both operands are reduced first, so the result lies in [0, m).
// m must be nonzero: reducing modulo zero is undefined behaviour in C.
unsigned long long AddMod(unsigned long long a, unsigned long long b, unsigned long long m){
    unsigned long long x = a % m;
    unsigned long long y = b % m;
    unsigned long long gap = m - y; // distance from y up to m
    // Once x reaches gap, x + y would pass m, so subtract gap instead.
    return (x < gap) ? x + y : x - gap;
}

// SubMod returns (a - b) mod m in [0, m) without unsigned wrap-around.
// Both operands are reduced first. m must be nonzero.
unsigned long long SubMod(unsigned long long a, unsigned long long b, unsigned long long m){
    unsigned long long x = a % m;
    unsigned long long y = b % m;
    if (x >= y) {
        return x - y;
    }
    // x < y: borrow one copy of m; m - y is positive and the sum stays below m.
    return x + (m - y);
}

// Lshift32Mod returns 2^32 a (mod m).
//
// 2^32 a can need 96 bits, so it is never computed directly. Let A = 2^32 a
// and let q* = floor(A/m) be the exact quotient. The quotient is
// underestimated by q = floor(A/(m+err)) = floor(a/(mHi+1)), where
// err in (0, 2^32] makes m+err = 2^32 (mHi+1). Because
//
//     q* - q < A err/(m (m+err)) + 1,
//
// the remainder satisfies A - q m < 2 m + 2^64. Only a few correction steps
// are needed before it fits in 64 bits again. m must be nonzero.
unsigned long long Lshift32Mod(unsigned long long a, unsigned long long m){
    unsigned long long x = a % m;
    unsigned long long mHi = m >> 32;
    unsigned long long mLo = m & 0xFFFFFFFFULL;

    unsigned long long q = x / (mHi + 1);
    unsigned long long qHi = q >> 32;
    unsigned long long qLo = q & 0xFFFFFFFFULL;

    // A - q m == hi * 2^32 - lo, where hi = x - floor(q m / 2^32) and lo is the
    // low 32 bits of q m (only qLo * mLo contributes to them).
    unsigned long long cross = qLo * mLo;
    unsigned long long lo = cross & 0xFFFFFFFFULL;
    unsigned long long hi = x - ((cross >> 32) + qLo * mHi + qHi * mLo + ((qHi * mHi) << 32));

    // Remove further copies of m until hi * 2^32 fits in 64 bits.
    while (hi > 0xFFFFFFFFULL) {
        hi -= mHi;
        lo += mLo;
    }
    return SubMod(hi << 32, lo, m);
}

// MulMod returns a b (mod m) using only 64-bit arithmetic.
//
// With a = aHi 2^32 + aLo and b = bHi 2^32 + bLo,
//     a b = 2^32 (2^32 aHi bHi + (aLo bHi + aHi bLo)) + aLo bLo,
// and each 32x32-bit partial product fits in 64 bits. The bracket is evaluated
// from the inside out, reducing modulo m after each multiplication by 2^32.
unsigned long long MulMod(unsigned long long a, unsigned long long b, unsigned long long m){
    unsigned long long aHi = a >> 32, aLo = a & 0xFFFFFFFFULL;
    unsigned long long bHi = b >> 32, bLo = b & 0xFFFFFFFFULL;

    unsigned long long cross = AddMod(aLo * bHi, aHi * bLo, m);
    unsigned long long top = Lshift32Mod(aHi * bHi, m);
    unsigned long long upper = Lshift32Mod(AddMod(cross, top, m), m);
    return AddMod(aLo * bLo, upper, m);
}

// PowMod returns a^b (mod m), with 0^0 taken as 1, using right-to-left binary
// exponentiation. base runs through a^(2^k) mod m and is folded into the
// result for every set bit k of b. Starting from 1 % m yields 0 when m == 1.
// m must be nonzero.
unsigned long long PowMod(unsigned long long a, unsigned long long b, unsigned long long m){
    unsigned long long result = 1 % m;
    unsigned long long base = a;
    for (unsigned long long e = b; e; e >>= 1) {
        if (e & 1) {
            result = MulMod(result, base, m);
        }
        base = MulMod(base, base, m);
    }
    return result;
}


// Demo driver. It exercises PowMod with inputs near the type limits
// (a, b just below 2^32; m just below 2^64) and with the example from the
// question, (2^32 - 1)^3 mod (2^64 - 1).
int main(void){
    // The ULL suffixes matter. 18446743979220271189 exceeds LLONG_MAX where
    // long long is 64 bits, so an unsuffixed decimal literal has no standard
    // type able to hold it.
    unsigned long long a = 4294967189ULL;
    unsigned long long b = 4294967231ULL;
    unsigned long long m = 18446743979220271189ULL;
    unsigned long long c = PowMod(a, b, m);

    printf("%llu %llu %llu %llu\n", a, b, m, c);

    // The question's case. 2^32 - 1 divides 2^64 - 1 = (2^32 - 1)(2^32 + 1),
    // so the result is 4 (2^32 - 1) = 17179869180.
    printf("%llu\n", PowMod(4294967295ULL, 3ULL, 18446744073709551615ULL));
    return 0;
}

关于c - 如何实现模幂运算,最多需要两倍于要在 C 中进行模幂运算的数字的字节大小?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/41397158/

相关文章:

c - C 中是否有替代 strtoull() 函数的方法?

c - "Too few arguments"错误试图运行我编译的程序

c - WTS_CLIENT_ADDRESS 地址无法正确打印

用 1 和 2 求整数的二进制的算法

java - 查找从 2 到 1000 的所有素数的算法不起作用

html - 在 div 内的 div 上使用 "overflow:hidden;"

css - 使用全宽,不包括 "position: absolute"的溢出滚动条

objective-c - 100 <= x <= 150 作为 if () 中的参数,搞笑

c++ - 同时最小值和最大值

c++ - 没有有用且可靠的方法来检测 C/C++ 中的整数溢出?