c - 椭圆曲线离散对数

标签 c segmentation-fault elliptic-curve ecdsa

我正在尝试使用 Pollard rho 算法求解椭圆曲线离散对数(即求满足 Q = kP 的 k),于是搜索了用 C 编写的实现,并找到了一个。但在 main 函数中加入我自己问题的数据后,运行时出现了 segmentation fault (core dumped)。

#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>      // clock(), CLOCKS_PER_SEC
#include <sys/time.h>

#include <gmp.h>

#include <openssl/bn.h>
#include <openssl/crypto.h> // OPENSSL_free
#include <openssl/ec.h>
#include <openssl/obj_mac.h> // for NID_secp256k1

#define POLLARD_SET_COUNT 16

#if defined(WIN32) || defined(_WIN32)
#define EXPORT __declspec(dllexport)
#else
#define EXPORT
#endif

#define MAX_RESTART 100

/*
 * Map a point to one of POLLARD_SET_COUNT random-walk partitions.
 *
 * The index is the last octet of the point's uncompressed encoding
 * (EC_POINT_point2oct) reduced modulo POLLARD_SET_COUNT. If the point cannot
 * be encoded, 0 is returned, so the caller always gets an index in
 * [0, POLLARD_SET_COUNT).
 */
int ec_point_partition(const EC_GROUP *ecgrp, const EC_POINT *x) {
    size_t len = EC_POINT_point2oct(ecgrp, x, POINT_CONVERSION_UNCOMPRESSED, NULL, 0, NULL);
    if (len == 0)   /* encoding failed: avoid a zero-length VLA and reading ret[-1] */
        return 0;

    unsigned char ret[len];
    if (EC_POINT_point2oct(ecgrp, x, POINT_CONVERSION_UNCOMPRESSED, ret, len, NULL) != len)
        return 0;

    return ret[len - 1] % POLLARD_SET_COUNT;
}

// P generator 
// Q result*P
// order of the curve
// result
//Reference: J. Sattler and C. P. Schnorr, "Generating random walks in groups"

int elliptic_pollard_rho_dlog(const EC_GROUP *group, const EC_POINT *P, const EC_POINT *Q, const BIGNUM *order, BIGNUM *res) {

    printf("Pollard rho discrete log algorithm... \n");

    BN_CTX* ctx;
    ctx = BN_CTX_new();

    int i, j;
    int iterations = 0;

    if ( !EC_POINT_is_on_curve(group, P, ctx ) || !EC_POINT_is_on_curve(group, Q, ctx ) ) return 1;

    EC_POINT *X1 = EC_POINT_new(group);
    EC_POINT *X2 = EC_POINT_new(group);

    BIGNUM *c1 = BN_new();
    BIGNUM *d1 = BN_new();
    BIGNUM *c2 = BN_new();
    BIGNUM *d2 = BN_new();

    BIGNUM* a[POLLARD_SET_COUNT];
    BIGNUM* b[POLLARD_SET_COUNT];
    EC_POINT* R[POLLARD_SET_COUNT];

    BN_zero(c1); BN_zero(d1);
    BN_zero(c2); BN_zero(d2);


    for (i = 0; i < POLLARD_SET_COUNT; i++) {   

        a[i] = BN_new();
        b[i] = BN_new();
        R[i] = EC_POINT_new(group);

        BN_rand_range(a[i], order);     
        BN_rand_range(b[i], order);

        // R = aP + bQ

        EC_POINT_mul(group, R[i], a[i], Q, b[i], ctx);
        //ep_norm(R[i], R[i]);
    }

    BN_rand_range(c1, order);       
    BN_rand_range(d1, order);       


    // X1 = c1*P + d1*Q
    EC_POINT_mul(group, X1, c1, Q, d1,  ctx);  
    //ep_norm(X1, X1);

    BN_copy(c2, c1);
    BN_copy(d2, d1);
    EC_POINT_copy(X2, X1);


    double work_time = (double) clock();
    do {
        j = ec_point_partition(group, X1);
        EC_POINT_add(group, X1, X1, R[j], ctx);

        BN_mod_add(c1, c1, a[j], order, ctx); 

        BN_mod_add(d1, d1, b[j], order, ctx); 

        for (i = 0; i < 2; i++) {
            j = ec_point_partition(group, X2);

            EC_POINT_add(group, X2, X2, R[j], ctx);

            BN_mod_add(c2, c2, a[j], order, ctx); 

            BN_mod_add(d2, d2, b[j], order, ctx);
        }

        iterations++;
        printf("Iteration %d \r",iterations );
    } while ( EC_POINT_cmp(group, X1, X2, ctx) != 0 ) ;


    printf("\n ");

    work_time = ( (double) clock() - work_time ) / (double)CLOCKS_PER_SEC;

    printf("Number of iterations %d %f\n",iterations, work_time );

    BN_mod_sub(c1, c1, c2, order, ctx);
    BN_mod_sub(d2, d2, d1, order, ctx);

    if (BN_is_zero(d2) == 1) return 1;


    //d1 = d2^-1 mod order  
    BN_mod_inverse(d1, d2, order, ctx);

    BN_mod_mul(res, c1, d1, order, ctx);

    for (int k = 0; k < POLLARD_SET_COUNT; ++k) {
        BN_free(a[k]); 
        BN_free(b[k]);
        EC_POINT_free(R[k]);
    }
    BN_free(c1); BN_free(d1);
    BN_free(c2); BN_free(d2);
    EC_POINT_free(X1); EC_POINT_free(X2);

    BN_CTX_free(ctx);
    return 0;
}


/*
 * Example driver: solve Q = k*P on y^2 = x^3 - x (mod p) and print k in decimal.
 *
 * The previous version crashed for these reasons:
 *   - it parsed decimal strings with BN_bin2bn, which expects big-endian
 *     binary, and passed sizeof(pointer) as the length;
 *   - it handed an uninitialised `BIGNUM *res` to the solver;
 *   - it wrote through an uninitialised `char *str` with BN_bn2mpi.
 * N is assumed to be the order of P, as given in the problem statement.
 */
int main(int argc, char *argv[])
{
    (void) argc;
    (void) argv;

    int rc = EXIT_FAILURE;
    BN_CTX *ctx = BN_CTX_new();
    BIGNUM *p = NULL, *a = NULL, *b = NULL;
    BIGNUM *XP = NULL, *YP = NULL, *XQ = NULL, *YQ = NULL, *N = NULL;
    BIGNUM *res = BN_new();
    EC_GROUP *g = NULL;
    EC_POINT *P = NULL, *Q = NULL;
    char *str = NULL;

    if (ctx == NULL || res == NULL)
        goto done;

    /* BN_dec2bn allocates the BIGNUM and returns 0 on a parse/alloc error. */
    if (!BN_dec2bn(&p, "134747661567386867366256408824228742802669457")
        || !BN_dec2bn(&a, "-1")
        || !BN_dec2bn(&b, "0")
        || !BN_dec2bn(&XP, "18185174461194872234733581786593019886770620")
        || !BN_dec2bn(&YP, "74952280828346465277451545812645059041440154")
        || !BN_dec2bn(&XQ, "76468233972358960368422190121977870066985660")
        || !BN_dec2bn(&YQ, "33884872380845276447083435959215308764231090")
        || !BN_dec2bn(&N, "2902021510595963727029"))
        goto done;

    /* Bring a = -1 into [0, p) before building the curve. */
    if (!BN_nnmod(a, a, p, ctx))
        goto done;

    g = EC_GROUP_new_curve_GFp(p, a, b, ctx);
    if (g == NULL)
        goto done;

    P = EC_POINT_new(g);
    Q = EC_POINT_new(g);
    if (P == NULL || Q == NULL
        || !EC_POINT_set_affine_coordinates_GFp(g, P, XP, YP, ctx)
        || !EC_POINT_set_affine_coordinates_GFp(g, Q, XQ, YQ, ctx))
        goto done;

    if (elliptic_pollard_rho_dlog(g, P, Q, N, res) != 0) {
        fprintf(stderr, "discrete logarithm not found\n");
        goto done;
    }

    /* BN_bn2dec allocates the string; it is released with OPENSSL_free. */
    str = BN_bn2dec(res);
    if (str == NULL)
        goto done;
    printf("%s\n", str);
    rc = EXIT_SUCCESS;

done:
    OPENSSL_free(str);
    EC_POINT_free(P);
    EC_POINT_free(Q);
    EC_GROUP_free(g);
    BN_free(p); BN_free(a); BN_free(b);
    BN_free(XP); BN_free(YP); BN_free(XQ); BN_free(YQ); BN_free(N);
    BN_free(res);
    BN_CTX_free(ctx);
    return rc;
}

这是导致段错误的语句

    BN_bn2mpi(res,str); 

最佳答案

第 1 部分。 Python 版本。

更新:请参阅我的回答的新第 2 部分,其中我展示了与此 Python 版本相同算法的 C++ 版本。

你的任务很有趣!

也许您希望修复您的代码,但我决定从头开始实现纯 Python(答案的第 1 部分)和纯 C++(第 2 部分)解决方案,而不使用任何外部非标准模块。我认为这种没有依赖性的从头开始的解决方案对于教育目的非常有用。

这类算法相当复杂,而用 Python 可以比较容易地在短时间内把它实现出来。

在下面的代码中,我参考维基百科实现了 Pollard's Rho 离散对数算法(Pollard's Rho Discrete Logarithm)和椭圆曲线点乘(Elliptic Curve Point Multiplication)。

代码不依赖任何外部模块,只使用了几个 Python 内置模块。如果你通过 python -m pip install gmpy2 安装了 gmpy2 模块,并取消代码中 #import gmpy2 这一行的注释,就可以改用 gmpy2 进行大整数运算。

可以看到,我自己生成随机基点并计算它的阶(order),而没有使用任何现成的曲线,例如比特币的 secp256k1 或其他标准曲线(standard curves)。

main()函数的开头你可以看到我设置了bits = 24,这是曲线素数模数的位数,曲线阶数(不同点的数量)将具有大致相同的位大小。您可以将其设置为 bits = 32 以尝试解决更大曲线的任务。

众所周知,该算法的复杂度为 O(Sqrt(Curve_Order)),即大约需要这么多次椭圆曲线点加法。点加法并不是基本操作,本身也需要一定时间。因此,当曲线阶为 32 位(bits = 32)时,算法大约需要运行 10-15 秒。bits = 64 对 Python 来说耗时太长,但 C++ 版本(见第 2 部分)足够快,可以在一个小时左右破解 64 位曲线。

运行代码时,你有时会看到输出中显示 Pollard Rho 失败了几次。这是因为在 Pollard Rho 最后一步求模逆时,或在椭圆曲线点相加(结果为无穷远点 Infinite Point)时,遇到了无法求逆的情况;程序会换一个随机划分素数重新尝试。常规的 Pollard Rho 整数分解(Pollard Rho Integer Factorization)中也会偶尔出现同样的失败,即 GCD 等于 N 的时候。

Try it online!

import random
#random.seed(10)

class ECPoint:
    """A point on the curve y^2 = x^3 + A*x + B over GF(N).

    N is expected to be a prime with N % 4 == 3 (the constructor asserts the
    latter), so square roots can be taken as v ** ((N + 1) // 4).
    The point at infinity is stored as x = y = None. ``q`` caches the order of
    the point (0 by default); ``calc_q=True`` computes it with find_order().

    Adding P to -P, or doubling a point with y == 0, does not return the
    point at infinity: the modular inverse of 0 fails and ECPoint.InvError is
    raised. find_order() and pollard_rho_ec_log() rely on that exception.
    """
    gmpy2 = None
    #import gmpy2
    import random

    class InvError(Exception):
        """A modular inverse does not exist; ``value`` holds the constructor arguments."""
        def __init__(self, *args):
            self.value = args

    @classmethod
    def Int(cls, x):
        """Convert x to the integer type in use (gmpy2.mpz when enabled, else int)."""
        return int(x) if cls.gmpy2 is None else cls.gmpy2.mpz(x)

    @classmethod
    def fermat_prp(cls, n, trials = 32):
        """Fermat probable-prime test of n using ``trials`` random bases."""
        # https://en.wikipedia.org/wiki/Fermat_primality_test
        if n <= 16:
            return n in (2, 3, 5, 7, 11, 13)
        for i in range(trials):
            if pow(cls.random.randint(2, n - 2), n - 1, n) != 1:
                return False
        return True

    @classmethod
    def rand_prime(cls, bits):
        """Return a random probable prime with exactly ``bits`` bits."""
        while True:
            p = cls.random.randrange(1 << (bits - 1), 1 << bits) | 1
            if cls.fermat_prp(p):
                return p

    @classmethod
    def base_gen(cls, bits = 128, *, min_order_pfactor = 0):
        """Generate a random curve and base point whose order is known.

        Picks a random ``bits``-bit prime N with N % 4 == 3, a random point
        (x0, y0) and coefficient A, and derives B so the point lies on the
        curve. Retries until find_order() succeeds and every prime factor of
        the order is at least ``min_order_pfactor``.
        """
        while True:
            while True:
                N = cls.rand_prime(bits)
                if N % 4 != 3:
                    continue
                x0, y0, A = [cls.random.randrange(1, N) for i in range(3)]
                B = (y0 ** 2 - x0 ** 3 - A * x0) % N
                # Keep only points whose y is the square root the constructor checks for.
                y0_calc = pow(x0 ** 3 + A * x0 + B, (N + 1) // 4, N)
                if y0 == y0_calc:
                    break
            bp = ECPoint(A, B, N, x0, y0, calc_q = True)
            if bp.q is not None and min(bp.q_ps) >= min_order_pfactor:
                break
        assert bp.q > 1 and (bp.q + 1) * bp == bp
        return bp

    def __init__(self, A, B, N, x, y, *, q = 0, prepare = True, calc_q = False):
        """Create a point; with ``prepare`` reduce the inputs mod N and validate them.

        ``prepare=False`` stores the values as given (used internally for the
        results of arithmetic). ``calc_q=True`` sets ``q`` and ``q_ps`` from
        find_order().
        """
        if prepare:
            N = self.Int(N)
            assert (x is None) == (y is None), (x, y)
            A, B, x, y = [(self.Int(e) % N if e is not None else None) for e in [A, B, x, y]]
            # The order is not reduced mod N: it can exceed N (Hasse bound
            # N + 1 + 2*sqrt(N)), and reducing it would silently corrupt it.
            q = self.Int(q) if q is not None else None
            assert (4 * A ** 3 + 27 * B ** 2) % N != 0
            assert N % 4 == 3
            if x is not None:
                assert (y ** 2 - x ** 3 - A * x - B) % N == 0, (hex(N), hex((y ** 2 - x ** 3 - A * x) % N))
                assert y == pow(x ** 3 + A * x + B, (N + 1) // 4, N)
        self.A, self.B, self.N, self.x, self.y, self.q = A, B, N, x, y, q
        if calc_q:
            self.q, self.q_ps = self.find_order()

    def copy(self):
        """Return an unvalidated copy of this point."""
        return ECPoint(self.A, self.B, self.N, self.x, self.y, q = self.q, prepare = False)

    def inf(self):
        """Return the point at infinity on the same curve."""
        return ECPoint(self.A, self.B, self.N, None, None, q = self.q, prepare = False)

    def find_order(self, *, _m = 1, _ps = []):
        """Search for the order of this point by multiplying by prime powers.

        Returns ``(order, prime_factors)``, or ``(None, [])`` once the prime
        search bound is exceeded. Starting from r = _m * P, r is repeatedly
        multiplied by a prime p while the accumulated power stays <= 2*N. When
        a multiplication raises InvError (for prime N this means the point at
        infinity was reached), that prime power is folded into _m and the
        search restarts from _m * P. It ends when _m * P itself raises.
        ``_m`` and ``_ps`` are recursion state. ``_ps`` is never mutated, so
        the shared default list is safe.
        """
        try:
            r = _m * self
        except self.InvError:
            return _m, _ps
        B = 2 * self.N
        for p in self.gen_primes():
            if p * p > B * 2:
                return None, []
            assert _m % p != 0, (_m, p)
            assert p <= B, (p, B)
            hi = 1
            try:
                for cnt in range(1, 1 << 60):
                    hi *= p
                    if hi > B:
                        cnt -= 1
                        break
                    r = p * r
            except self.InvError:
                return self.find_order(_m = hi * _m, _ps = [p] * cnt + _ps)

    @classmethod
    def gen_primes(cls, *, ps = [2, 3]):
        """Yield primes in increasing order.

        The default ``ps`` list is shared between calls on purpose: it caches
        every prime found so far and is extended by trial division.
        """
        yield from ps
        for p in range(ps[-1] + 2, 1 << 60, 2):
            is_prime = True
            for e in ps:
                if e * e > p:
                    break
                if p % e == 0:
                    is_prime = False
                    break
            if is_prime:
                ps.append(p)
                yield ps[-1]

    def __add__(self, other):
        """Affine point addition (doubling when both points are equal).

        Raises InvError when the denominator is not invertible, e.g. for
        P + (-P) or when doubling a point with y == 0.
        """
        if self.x is None:
            return other.copy()
        if other.x is None:
            return self.copy()
        A, B, N, q = self.A, self.B, self.N, self.q
        Px, Py, Qx, Qy = self.x, self.y, other.x, other.y
        if Px == Qx and Py == Qy:
            s = ((Px * Px * 3 + A) * self.inv(Py * 2, N)) % N
        else:
            s = ((Py - Qy) * self.inv(Px - Qx, N)) % N
        x = (s * s - Px - Qx) % N
        y = (s * (Px - x) - Py) % N
        return ECPoint(A, B, N, x, y, q = q, prepare = False)

    def __rmul__(self, other):
        """Return ``other * self`` for a positive integer ``other`` (double-and-add)."""
        assert other > 0, other
        if other == 1:
            return self.copy()
        other = self.Int(other - 1)
        r = self
        while True:
            if other & 1:
                r = r + self
                if other == 1:
                    return r
            other >>= 1
            self = self + self

    @classmethod
    def inv(cls, a, n):
        """Return a^-1 mod n; raise InvError(gcd, a, n) when it does not exist."""
        a %= n
        if cls.gmpy2 is None:
            try:
                return pow(a, -1, n)
            except ValueError:
                import math
                raise cls.InvError(math.gcd(a, n), a, n)
        else:
            g, s, t = cls.gmpy2.gcdext(a, n)
            if g != 1:
                raise cls.InvError(g, a, n)
            return s % n

    def __repr__(self):
        return str(dict(x = self.x, y = self.y, A = self.A, B = self.B, N = self.N, q = self.q))

    def __eq__(self, other):
        """Points are equal when coordinates, curve parameters and cached order match."""
        if not isinstance(other, ECPoint):
            return NotImplemented
        return ((self.x, self.y, self.A, self.B, self.N, self.q) ==
                (other.x, other.y, other.A, other.B, other.N, other.q))

def pollard_rho_ec_log(a, b, bp):
    """Return x such that b == x * a, using Pollard's rho with Floyd cycle finding.

    a, b -- points; b is assumed to be a multiple of a.
    bp   -- point whose order ``bp.q`` is the modulus for the exponent
            bookkeeping (a is assumed to be a multiple of bp).

    Every walk point is kept as ea*a + eb*b. The step from a point depends on
    (x mod part_p) mod 3, where part_p is a random prime:
        0 -> add b,  1 -> double,  2 -> add a.
    A collision ea1*a + eb1*b == ea2*a + eb2*b gives
        x = (ea2 - ea1) / (eb1 - eb2)  (mod q).
    If that inverse, or any point addition, raises bp.InvError, the attempt is
    reported and restarted with a new partition prime.
    """
    # https://en.wikipedia.org/wiki/Pollard%27s_rho_algorithm_for_logarithms#Algorithm
    import math

    for itry in range(1 << 60):
        try:
            i = -1
            part_p = bp.rand_prime(max(3, int(math.log2(bp.N) / 2)))

            def step(x, ea, eb):
                """One walk step: return the next (point, a-exponent, b-exponent).

                The partition is computed once per point, and the second hare
                step reuses the intermediate point instead of recomputing it.
                This saves two point additions per iteration.
                """
                mod3 = ((x.x or 0) % part_p) % 3
                if mod3 == 0:
                    return b + x, ea, (eb + 1) % bp.q
                elif mod3 == 1:
                    return x + x, (2 * ea) % bp.q, (2 * eb) % bp.q
                else:
                    return a + x, (ea + 1) % bp.q, eb

            # Tortoise (x1) takes one step and hare (x2) two steps per iteration.
            x1, a1, b1 = bp.inf(), 0, 0
            x2, a2, b2 = x1, a1, b1

            for i in range(1, 1 << 60):
                x1, a1, b1 = step(x1, a1, b1)
                x2, a2, b2 = step(*step(x2, a2, b2))

                if x1 == x2:
                    return (bp.inv(b1 - b2, bp.q) * (a2 - a1)) % bp.q
        except bp.InvError as ex:
            print(f'Try {itry:>4}, Pollard-Rho failed, invert err at iter {i:>7},', ex.value)

def main():
    """Demo: build a random 24-bit curve, hide a discrete log, then recover it."""
    import random
    curve_bits = 24
    print('Generating base point, wait...')
    base = ECPoint.base_gen(curve_bits, min_order_pfactor = 10)
    factors = ' * '.join(map(str, base.q_ps))
    print('order', base.q, '=', factors)
    multiplier = random.randrange(1, base.q)
    secret = random.randrange(1, base.q)
    point_a = multiplier * base
    point_b = secret * point_a
    recovered = pollard_rho_ec_log(point_a, point_b, base)
    print('our x', secret, 'found x', recovered)
    same = secret * point_a == recovered * point_a
    print('equal points:', same)

if __name__ == '__main__':
    main()

输出:

Generating base point, wait...
order 5805013 = 19 * 109 * 2803
Try    0, Pollard-Rho failed, invert err at iter    1120, (109, 1411441, 5805013)
Try    1, Pollard-Rho failed, invert err at iter    3992, (19, 5231802, 5805013)
our x 990731 found x 990731
equal points: True

第 2 部分。 C++ 版本。

与上面几乎相同的代码,但用 C++ 重写。

这个 C++ 版本比 Python 快得多:在 1 GHz 的 CPU 上,C++ 代码破解 48 位曲线大约需要 1 分钟,而 Python 破解 32 位曲线也需要差不多的时间。

提醒一下,复杂度是 O(Sqrt(Curve_Order))。这意味着,如果 C++ 处理 48 位曲线(平方根为 2^24)与 Python 处理 32 位曲线(平方根为 2^16)花费的时间相同,那么 C++ 大约比 Python 版本快 2^24 / 2^16 = 2^8 = 256 倍。

以下版本只能用 Clang 编译,因为它用到了 128 位和 192 位整数。GCC 中也有 __int128,但没有 192 位整数类型。192 位整数只在 BarrettMod() 函数中使用,因此,如果把该函数体替换为 return x % n;,并删除 u192/TWord 的类型别名定义,就不再需要 192 位整数,也就可以用 GCC 编译了。(另外,较新版本的 Clang 已弃用 _ExtInt,推荐改用 _BitInt,需要相应修改 u192 的定义。)

我实现了 Barrett 约减(Barrett Reduction)算法,用基于乘法/移位/减法的 Barrett 公式代替基于慢速除法的取模运算(% N),可使模运算提速数倍。

Try it online!

#include <cstdint>
#include <random>
#include <stdexcept>
#include <type_traits>
#include <iomanip>
#include <iostream>
#include <string>
#include <chrono>
#include <cmath>

using u64 = uint64_t;
using u128 = unsigned __int128;
using u192 = unsigned _ExtInt(192);
using Word = u64;
using DWord = u128;
using SWord = std::make_signed_t<Word>;
using TWord = u192;

#define ASSERT_MSG(cond, msg) { if (!(cond)) throw std::runtime_error("Assertion (" #cond ") failed at line " + std::to_string(__LINE__) + "! Msg '" + std::string(msg) + "'."); }
#define ASSERT(cond) ASSERT_MSG(cond, "")
#define LN { g_log << " LN " << __LINE__ << " " << std::flush; }
#define DUMP(x) { g_log << " " << (#x) << " = " << (x) << " " << std::flush; }

static auto & g_log = std::cout;

class ECPoint {
public:
    class InvError : public std::runtime_error {
    public:
        InvError(Word const & gcd, Word const & x, Word const & mod)
            : std::runtime_error("(gcd " + std::to_string(gcd) + ", x " + std::to_string(x) +
                ", mod " + std::to_string(mod) + ")") {}
    };
    
    static Word pow_mod(Word a, Word b, Word const & c) {
        // https://en.wikipedia.org/wiki/Modular_exponentiation
        Word r = 1;
        while (b != 0) {
            if (b & 1)
                r = (DWord(r) * a) % c;
            a = (DWord(a) * a) % c;
            b >>= 1;
        }
        return r;
    }
    
    static Word rand_range(Word const & begin, Word const & end) {
        u64 const seed = (u64(std::random_device{}()) << 32) + std::random_device{}();
        thread_local std::mt19937_64 rng{seed};
        ASSERT(begin < end);
        return std::uniform_int_distribution<Word>(begin, end - 1)(rng);
    }
    
    static bool fermat_prp(Word const & n, size_t trials = 32) {
        // https://en.wikipedia.org/wiki/Fermat_primality_test
        if (n <= 16)
            return n == 2 || n == 3 || n == 5 || n == 7 || n == 11 || n == 13;
        for (size_t i = 0; i < trials; ++i)
            if (pow_mod(rand_range(2, n - 2), n - 1, n) != 1)
                return false;
        return true;
    }
    
    static Word rand_prime_range(Word begin, Word end) {
        while (true) {
            Word const p = rand_range(begin, end) | 1;
            if (fermat_prp(p))
                return p;
        }
    }
    
    static Word rand_prime(size_t bits) {
        return rand_prime_range(Word(1) << (bits - 1), Word((DWord(1) << bits) - 1));
    }
    
    std::tuple<Word, size_t> BarrettRS(Word n) {
        size_t constexpr extra = 3;
        for (size_t k = 0; k < sizeof(DWord) * 8; ++k) {
            if (2 * (k + extra) < sizeof(Word) * 8)
                continue;
            if ((DWord(1) << k) <= DWord(n))
                continue;
            k += extra;
            ASSERT_MSG(2 * k < sizeof(DWord) * 8, "k " + std::to_string(k));
            DWord r = (DWord(1) << (2 * k)) / n;
            ASSERT_MSG(DWord(r) < (DWord(1) << (sizeof(Word) * 8)),
                "k " + std::to_string(k) + " n " + std::to_string(n));
            ASSERT(2 * k >= sizeof(Word) * 8);
            return std::make_tuple(Word(r), size_t(2 * k - sizeof(Word) * 8));
        }
        ASSERT(false);
    }
    
    template <bool Adjust>
    static Word BarrettMod(DWord const & x, Word const & n, Word const & r, size_t s) {
        //return x % n;
        DWord const q = DWord(((TWord(x) * r) >> (sizeof(Word) * 8)) >> s);
        Word t = Word(DWord(x) - q * n);
        if constexpr(Adjust) {
            Word const mask = ~Word(SWord(t - n) >> (sizeof(Word) * 8 - 1));
            t -= mask & n;
        }
        return t;
    }
    
    static Word Adjust(Word const & a, Word const & n) {
        return a >= n ? a - n : a;
    }
    
    Word modNn(DWord const & a) const { return BarrettMod<false>(a, N_, N_br_, N_bs_); }
    Word modNa(DWord const & a) const { return BarrettMod<true>(a, N_, N_br_, N_bs_); }
    Word modQn(DWord const & a) const { return BarrettMod<false>(a, q_, Q_br_, Q_bs_); }
    Word modQa(DWord const & a) const { return BarrettMod<true>(a, q_, Q_br_, Q_bs_); }
    
    static Word mod(DWord const & a, Word const & n) { return a % n; }
    
    static ECPoint base_gen(size_t bits = 128, Word min_order_pfactor = 0) {
        while (true) {
            Word const N = rand_prime(bits);
            if (mod(N, 4) != 3)
                continue;
            Word const
                x0 = rand_range(1, N), y0 = rand_range(1, N), A = rand_range(1, N),
                B = mod(mod(DWord(y0) * y0, N) + N * 2 - mod(DWord(mod(DWord(x0) * x0, N)) * x0, N) - mod(DWord(A) * x0, N), N),
                y0_calc = pow_mod(mod(DWord(y0) * y0, N), (N + 1) >> 2, N);
            if (y0 != y0_calc)
                continue;
            auto const bp = ECPoint(A, B, N, x0, y0, 0, true, true);
            
            auto BpCheckOrder = [&]{
                for (auto e: bp.q_ps())
                    if (e < min_order_pfactor)
                        return false;
                return true;
            };
            
            if (!(bp.q() != 0 && !bp.q_ps().empty() && BpCheckOrder()))
                continue;
            ASSERT(bp.q() > 1 && bp * (bp.q() + 1) == bp);
            return bp;
        }
        ASSERT(false);
    }
    
    ECPoint(Word A, Word B, Word N, Word x, Word y, Word q = 0, bool prepare = true, bool calc_q = false) {
        if (prepare) {
            A = mod(A, N); B = mod(B, N); x = mod(x, N); y = mod(y, N); q = mod(q, N);
            ASSERT(mod(4 * mod(DWord(mod(DWord(A) * A, N)) * A, N) + 27 * mod(DWord(B) * B, N), N) != 0);
            ASSERT(mod(N, 4) == 3);
            if (!(x == 0 && y == 0)) {
                ASSERT(mod(mod(DWord(y) * y, N) + 3 * N - mod(DWord(mod(DWord(x) * x, N)) * x, N) - mod(DWord(A) * x, N) - B, N) == 0);
                ASSERT(y == pow_mod(mod(DWord(mod(DWord(x) * x, N)) * x, N) + mod(DWord(A) * x, N) + B, (N + 1) >> 2, N));
            }
            std::tie(N_br_, N_bs_) = BarrettRS(N);
            if (q != 0)
                std::tie(Q_br_, Q_bs_) = BarrettRS(q);
        }
        std::tie(A_, B_, N_, x_, y_, q_) = std::tie(A, B, N, x, y, q);
        if (calc_q) {
            std::tie(q_, q_ps_) = find_order();
            if (q_ != 0)
                std::tie(Q_br_, Q_bs_) = BarrettRS(q_);
        }
    }
    
    auto copy() const {
        return ECPoint(A_, B_, N_, x_, y_, q_, false);
    }
    
    auto inf() const {
        return ECPoint(A_, B_, N_, 0, 0, q_, false);
    }
    
    static auto const & gen_primes(Word const B) {
        thread_local std::vector<Word> ps = {2, 3};
        
        for (Word p = ps.back() + 2; p <= B; p += 2) {
            bool is_prime = true;
            for (auto const e: ps) {
                if (e * e > p)
                    break;
                if (p % e == 0) {
                    is_prime = false;
                    break;
                }
            }
            if (is_prime)
                ps.push_back(p);
        }
        
        return ps;
    }
    
    std::tuple<Word, std::vector<Word>> find_order(Word _m = 1, std::vector<Word> _ps = {}) const {
        ASSERT(_m <= 2 * N_);
        
        if constexpr(1) {
            auto r = *this;
            try {
                r *= _m;
            } catch (InvError const &) {
                return std::make_tuple(_m, _ps);
            }
            Word const B = 2 * N_;
            for (Word const p: gen_primes(std::llround(std::cbrt(B) + 1))) {
                if (p * p * p > B)
                    break;
                ASSERT(p <= B);
                size_t cnt = 0;
                Word hi = 1;
                try {
                    for (cnt = 1;; ++cnt) {
                        if (hi * p > B) {
                            cnt -= 1;
                            break;
                        }
                        hi *= p;
                        r *= p;
                     }
                } catch (InvError const & ex) {
                    _ps.insert(_ps.begin(), cnt, p);
                    return find_order(hi * _m, _ps);
                }
            }
        } else {
            // Alternative slower way
            auto r = *this;
            for (Word i = 0;; ++i)
                try {
                    r += *this;
                } catch (InvError const &) {
                    _ps.clear();
                    return std::make_tuple(i + 2, _ps);
                }
        }
        
        _ps.clear();
        return std::make_tuple(Word(0), _ps);
    }
    
    static std::tuple<Word, SWord, SWord> EGCD(Word const & a, Word const & b) {
        Word ro = 0, r = 0, qu = 0, re = 0;
        SWord so = 0, s = 0;
        std::tie(ro, r, so, s) = std::make_tuple(a, b, 1, 0);
        while (r != 0) {
            std::tie(qu, re) = std::make_tuple(ro / r, ro % r);
            std::tie(ro, r) = std::make_tuple(r, re);
            std::tie(so, s) = std::make_tuple(s, so - s * SWord(qu));
        }
        SWord const to = (SWord(ro) - SWord(a) * so) / SWord(b);
        return std::make_tuple(ro, so, to);
    }
    
    Word inv(Word a, Word const & n, size_t any_n_q = 0) const {
        ASSERT(n > 0);
        a = any_n_q == 0 ? mod(a, n) : any_n_q == 1 ? modNa(a) : any_n_q == 2 ? modQa(a) : 0;
        auto [gcd, s, t] = EGCD(a, n);
        if (gcd != 1)
            throw InvError(gcd, a, n);
        a = Word(SWord(n) + s);
        a = any_n_q == 0 ? mod(a, n) : any_n_q == 1 ? modNa(a) : any_n_q == 2 ? modQa(a) : 0;
        return a;
    }
    
    Word invN(Word a) const { return inv(a, N_, 1); }
    Word invQ(Word a) const { return inv(a, q_, 2); }
    
    ECPoint & operator += (ECPoint const & o) {
        if (x_ == 0 && y_ == 0) {
            *this = o;
            return *this;
        }
        if (o.x_ == 0 && o.y_ == 0)
            return *this;
        Word const Px = x_, Py = y_, Qx = o.x_, Qy = o.y_;
        Word s = 0;
        if ((Adjust(Px, N_) == Adjust(Qx, o.N_)) && (Adjust(Py, N_) == Adjust(Qy, o.N_)))
            s = modNn(DWord(modNn(DWord(Px) * Px * 3) + A_) * invN(Py * 2));
        else
            s = modNn(DWord(Py + 2 * N_ - Qy) * invN(Px + 2 * N_ - Qx));
        x_ = modNn(DWord(s) * s + 4 * N_ - Px - Qx);
        y_ = modNn(DWord(s) * (Px + 2 * N_ - x_) + 2 * N_ - Py);
        return *this;
    }
    
    ECPoint operator + (ECPoint const & o) const {
        ECPoint c = *this;
        c += o;
        return c;
    }
    
    ECPoint & operator *= (Word k) {
        auto const ok = k;
        ASSERT(k > 0);
        if (k == 1)
            return *this;
        k -= 1;
        auto r = *this, s = *this;
        while (true) {
            if (k & 1) {
                r += s;
                if (k == 1)
                    break;
            }
            k >>= 1;
            s += s;
        }
        if constexpr(0) {
            auto r2 = *this;
            for (u64 i = 1; i < ok; ++i)
                r2 += *this;
            ASSERT(r == r2);
        }
        *this = r;
        return *this;
    }
    
    ECPoint operator * (Word k) const {
        ECPoint r = *this;
        r *= k;
        return r;
    }
    
    bool operator == (ECPoint const & o) const {
        return A_ == o.A_ && B_ == o.B_ && N_ == o.N_ && q_ == o.q_ &&
            Adjust(x_, N_) == Adjust(o.x_, o.N_) && Adjust(y_, N_) == Adjust(o.y_, o.N_);
    }
    
    Word const & q() const { return q_; }
    std::vector<Word> const & q_ps() const { return q_ps_; }
    Word const & x() const { return x_; }
    
private:
    Word A_ = 0, B_ = 0, N_ = 0, q_ = 0, x_ = 0, y_ = 0, N_br_ = 0, Q_br_ = 0;
    size_t N_bs_ = 0, Q_bs_ = 0;
    std::vector<Word> q_ps_;
};

/// Return x with b == a * x, using Pollard's rho with Floyd cycle detection.
///
/// `bp` supplies the modulus q = bp.q() for the exponent bookkeeping (a is
/// assumed to be a multiple of bp, so its order divides q). Every walk point
/// is tracked as ea*a + eb*b. The step from a point depends on
/// (x mod part_p) mod 3, with part_p a random prime:
///     0 -> add b,  1 -> double,  2 -> add a.
/// A collision ea1*a + eb1*b == ea2*a + eb2*b gives
///     x = (ea2 - ea1) / (eb1 - eb2)  (mod q).
/// If that inverse, or any point addition, throws ECPoint::InvError, the
/// attempt is logged and retried with a new partition prime. The function
/// loops until it succeeds.
Word pollard_rho_ec_log(ECPoint const & a, ECPoint const & b, ECPoint const & bp) {
    // https://en.wikipedia.org/wiki/Pollard%27s_rho_algorithm_for_logarithms#Algorithm
    
    for (u64 itry = 0;; ++itry) {
        u64 i = 0;
        
        try {
            // The partition prime is drawn from [8, max(q/16, 32)). Clamping keeps the
            // range non-empty and containing primes for very small orders;
            // otherwise rand_range would throw or rand_prime_range would spin.
            Word const part_hi = (bp.q() >> 4) > 32 ? (bp.q() >> 4) : Word(32);
            Word const part_p = bp.rand_prime_range(8, part_hi);
            
            auto ModQ = [&](Word n) {
                return n >= bp.q() ? n - bp.q() : n;
            };
            
            // One walk step: the partition is taken from the point before it
            // moves, and the point and both exponents advance together. Doing
            // the hare's two steps in place avoids recomputing its
            // intermediate point, which saves two point additions per iteration.
            auto const step = [&](ECPoint & x, Word & ea, Word & eb) {
                switch ((x.x() % part_p) % 3) {
                case 0:
                    x = b + x;
                    eb = ModQ(eb + 1);
                    break;
                case 1:
                    x = x + x;
                    ea = ModQ(2 * ea);
                    eb = ModQ(2 * eb);
                    break;
                default:
                    x = a + x;
                    ea = ModQ(ea + 1);
                    break;
                }
            };
            
            Word a1 = 0, b1 = 0, a2 = 0, b2 = 0;
            ECPoint x1 = bp.inf(), x2 = bp.inf();
            
            for (i = 1;; ++i) {
                step(x1, a1, b1);   // tortoise: one step
                step(x2, a2, b2);   // hare: two steps
                step(x2, a2, b2);
                
                if (x1 == x2)
                    return bp.modQa(DWord(bp.invQ(bp.q() + b1 - b2)) * (bp.q() + a2 - a1));
            }
        } catch (ECPoint::InvError const & ex) {
            g_log << "Try " << std::setfill(' ') << std::setw(4) << itry << ", Pollard-Rho failed, invert err at iter "
                << std::setw(7) << i << ", " << ex.what() << std::endl;
        }
    }
}

/// Demo: generate a random 36-bit curve, hide a discrete logarithm and
/// recover it with pollard_rho_ec_log(), printing timings along the way.
void test() {
    // steady_clock is monotonic. high_resolution_clock may alias the
    // adjustable system clock, so intervals measured with it can jump.
    auto const gtb = std::chrono::steady_clock::now();
    auto Time = [&]() -> double {
        return std::chrono::duration_cast<std::chrono::milliseconds>(
            std::chrono::steady_clock::now() - gtb).count() / 1'000.0;
    };
    double tb = 0;
    size_t constexpr bits = 36;
    g_log << "Generating base point, wait... " << std::flush;
    tb = Time();
    auto const bp = ECPoint::base_gen(bits, 50);
    g_log << "Time " << Time() - tb << " sec" << std::endl;
    g_log << "order " << bp.q() << " = ";
    // Join the prime factors with " * ", without a trailing separator.
    char const * sep = "";
    for (auto e: bp.q_ps()) {
        g_log << sep << e;
        sep = " * ";
    }
    g_log << std::endl;
    Word const k0 = ECPoint::rand_range(1, bp.q()),
               x  = ECPoint::rand_range(1, bp.q());
    auto a = bp * k0;
    auto b = a * x;
    g_log << "Searching discrete logarithm... " << std::endl;
    tb = Time();
    Word const x_calc = pollard_rho_ec_log(a, b, bp);
    g_log << "Time " << Time() - tb << " sec" << std::endl;
    g_log << "our x " << x << ", found x " << x_calc << std::endl;
    g_log << "equal points: " << std::boolalpha << (a * x == a * x_calc) << std::endl;
}

/// Entry point: runs the demo, reports any exception, and exits with a
/// non-zero status if the demo threw.
int main() {
    try {
        test();
    } catch (std::exception const & ex) {
        g_log << "Exception: " << ex.what() << std::endl;
        return 1;
    }
    return 0;
}

输出:

Generating base point, wait... Time 38.932 sec
order 195944962603297 = 401 * 4679 * 9433 * 11071 *
Searching discrete logarithm...
Time 69.791 sec
our x 15520105103514, found x 15520105103514
equal points: true

关于c - 椭圆曲线离散对数,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/30249086/

相关文章:

c - 在使用 0 和 memset 初始化之间进行优化

c - 在 while 循环中编辑 int 时 C 中的段错误

pkcs#11 - 带有 pkcs11interop 的热门使用机制 CKM_ECDH1_DERIVE

encryption - 雅可比坐标中的椭圆曲线相加

c++ - 分析永不退出的基于 C 或 C++ 的应用程序

c - 当我在c中添加带有unsigned long的字符数组时会发生什么?

c - 将文件中的数据分解为结构体

c++ - 为什么在插入 map 时出现段错误?

c - 长整数作为 C 中的数组索引给出段错误

android - 如何使用secp256r1类型的椭圆曲线 key 对在Android中加密和解密数据?