跳转到内容

RSA

Updated: at 04:30

RSA,Rivest-Shamir-Adleman 算法,是一个常见的非对称加密算法。本文将简明扼要通俗易懂地介绍 RSA 的原理,并给出 Python 实现。

本文同步发表于我的博客 https://clouder0.com/zh-cn/posts/rsa/

Why we need RSA?

加密的需求大家都很熟悉,但非对称加密呢?我们为什么需要非对称加密?

想象以下的场景:

这是非对称加密的两大经典应用场景。

如果只才能对称加密的话,这两种场景都是无法实现的:数字签名自然不必说,加密通讯的话,由于你需要让公众具有加密的能力,但又不希望他们能够解密,自然也需要非对称。

How is this possible?为什么还能做出这种「有两种密码,一个加密一个解密」的神奇算法呢?

这一般都是利用了非对称性。例如:将两个质数乘起来得到结果是简单的,但想要对某个大数做质因数分解,复杂度则极其高。

How RSA works…

RSA 利用的就是质因数分解复杂度的非对称性。

我们首先选择两个足够大的质数,记为 p,qp,q,然后:

好的,密钥生成部分结束了。

接下来,我们将 (n,e)(n, e) 作为加密密钥,(n,d)(n,d) 作为解密密钥。而剩下的 p,q,λ(n)p,q,\lambda(n) 应当保密或者直接扔掉。


然后就是加密了,加密相当的简单啊,加入我们想要传递原文 MM,首先使用 padding 将其变成 mm,满足 0m<n0 \le m < n. 这里的 padding 只要是一种可逆的变换就行了。

然后计算:cme(modn)c \equiv m^e \pmod n,这里的 cc 就是我们的加密结果了。

使用快速幂,可以在较短的时间内完成计算。


解密也相当的简单,我们持有密文 cc,想要获得 padded 后的原文 mm,那么:

cd(me)dm(modn)c^d \equiv (m^e)^d \equiv m \pmod n

这里利用的核心原理是:ed1(modn)ed \equiv 1 \pmod n,实际上这就是 dd 的定义式。

相信大家已经完全理解 RSA 了,笑。

Math behind the scene

让我们思考一下,RSA 算法的执行流程已经讲完了,但它为什么能保证安全性、为什么能保证正确性呢?

RSA 的核心原理是:eedd 只有一个公开。而 medmλ(n)m(modn)m^{ed} \equiv m^{\lambda(n)} \equiv m \pmod n.
这里 edλ(n)(modn)ed \equiv \lambda(n) \pmod n 就是解密密钥 dd 的定义式。而 mλ(n)m(modn)m^{\lambda(n)} \equiv m \pmod n 就是 λ(n)\lambda(n) 的定义式。

实际上,eedd 是相当对称的。假如持有 ee 进行加密,加密后 c=mec=m^e,则 cdmc^d \equiv m. 用 dd 加密也是一样的:c=md,ce=mc=m^d, c^e = m.

也就是——实际上持有公钥的用户也可以既加密、又解密……吗?比如我们原本约定好公钥加密,私钥解密,that’s fine. 但哪天你抽风了说我们换换位置,公钥解密私钥加密,那也是无缝切换。

当然,工程实践上公钥经常取固定的 e=65537e=65537,嘛。


接下来还有一个问题,λ(n)=lcm(p1,q1)\lambda(n) = \operatorname{lcm}(p-1,q-1),为什么就有 mλ(n)m(modn)m^{\lambda(n)} \equiv m \pmod n

根据众所周知的费马小定理,我们知道:当 pp 为素数时,ap11(modp)a^{p-1} \equiv 1 \pmod p.

而当 n=pqn=pq 时,显然 nn 就不是素数了,我们要找到 aλ(n)1(modn)a^{\lambda(n)} \equiv 1 \pmod n,这个时候可以使用欧拉定理:

ab{abmodφ(p),b<φ(p)abmodφ(p)+φ(p),bφ(p)(modp)a^b \equiv \begin{cases} a^{b \bmod \varphi(p)},b < \varphi(p) \\ a^{b \bmod \varphi(p) + \varphi(p)},b \geq \varphi(p) \end{cases} \pmod{p}

其中 φ(p)\varphi(p) 为欧拉函数。欧拉函数满足积性,也就是 φ(pq)=φ(p)×φ(q)\varphi(pq) = \varphi(p) \times \varphi(q). 并且有对于素数 ppφ(p)=p1\varphi(p) = p-1.

那么 φ(n)=φ(pq)=φ(p)×φ(q)=(p1)(q1)\varphi(n) = \varphi(pq) = \varphi(p) \times \varphi(q) = (p-1)(q-1),非常 reasonable.

那么显然,我们就可以得到:

aφ(n)a01(modn)a^{\varphi(n)} \equiv a^0 \equiv 1 \pmod n

这就算是求出了一个满足需要的 λ(n)\lambda(n)…了吗?注意到我们的定义是最小的 mm 使得 am1(modn)a^m \equiv 1 \pmod n,这里的 φ(n)\varphi(n) 未必是最小的。

当然,实际上是不是最小的其实对 RSA 影响不大。

接下来就是 Carmichael function,其计算如下:

λ(n)={φ(n),if n is 1,2,3,4 or an odd prime power12φ(n), if n=2r,r3lcm(λ(n1),,λ(nk)), if n=n1n2nk, where ni are power of distinct prime numbers\lambda(n) = \begin{cases} \varphi(n), &\text{if } n \text{ is }1,2,3,4 \text{ or an odd prime power} \\ \dfrac{1}{2}\varphi(n), &\text{ if } n = 2^r, r \ge 3 \\ \operatorname{lcm}\left( \lambda(n_1),\cdots,\lambda(n_k) \right), &\text{ if } n = n_1n_2\cdots n_k, \text{ where } n_i \text{ are}\\ &\text{ power of distinct prime numbers} \end{cases}

在这里,因为 n=pqn=pqp,qp,q 都是质数,那么 λ(pq)=lcm(φ(p),φ(q))=lcm(p1,q1)\lambda(pq) = \operatorname{lcm}(\varphi(p),\varphi(q))= \operatorname{lcm}(p-1,q-1).

Implementation

涉及到大数运算,人生苦短,我用 Python.

但是 Python 确实不是很快,我决定使用稍微短一些的 pq. 1024bits 吧,这样最终的 nn 就是 2048bits.

以下是核心代码:

import random


def miller_rabin(n: int, k: int):
    """use miller rabin method to test prime."""
  
    if n == 2:
        return True

    if n % 2 == 0:
        return False

    r, s = 0, n - 1
    while s % 2 == 0:
        r += 1
        s //= 2
    for _ in range(k):
        a = random.randrange(2, n - 1)
        x = pow(a, s, n)
        if x == 1 or x == n - 1:
            continue
        for _ in range(r - 1):
            x = pow(x, 2, n)
            if x == n - 1:
                break
        else:
            return False
    return True


def exgcd(a: int, b: int):
    """exgcd to cacl inverse."""
    if b == 0:
        return a, 1, 0
    d, x, y = exgcd(b, a % b)
    x, y = y, x - (a // b) * y
    return d, x, y


def is_prime(n: int) -> bool:
    return miller_rabin(n, 40)

def inv(a: int, m: int) -> int:
    """calc modular inverse."""
    d, x, y = exgcd(a, m)
    if d != 1:
        raise RuntimeError("modular inverse does not exist")
    return x % m


def rsa_encrypt(m: int, e: int, n: int) -> int:
    return pow(m, e, n)


def rsa_decrypt(c: int, d: int, n: int) -> int:
    return pow(c, d, n)


def gcd(a: int, b: int) -> int:
    if b == 0:
        return a
    return gcd(b, a % b)


def lcm(a: int, b: int) -> int:
    return a * b // gcd(a, b)


def rsa_gen(p: int, q: int) -> tuple[int, int, int]:
    n = p * q
    l = lcm(p - 1, q - 1)
    e = 65537
    d = inv(e, l)
    return n, e, d


def get_big_prime():
    while True:
        p = random.getrandbits(1024)
        if is_prime(p):
            return p


def get_pq() -> tuple[int, int]:
    return get_big_prime(), get_big_prime()

def main():
    n, e, d = rsa_gen(*get_pq())
    print("n =", n)
    print("e =", e)
    print("d =", d)
    origin = int(input("origin: "))
    c = rsa_encrypt(origin, e, n)
    print("c =", c)
    print("origin =", rsa_decrypt(c, d, n))

    assert origin == rsa_decrypt(c, d, n)

    print("OK")


if __name__ == "__main__":
    main()

一般而言,RSA 的速度较为缓慢,我们可以将 RSA 和对称加密配合使用,比如说用 RSA 传递对称加密的密钥,以实现加密通讯。

处理的长度过长的时候,需要分块。emmm,注意到计算在 modn\bmod n 下进行,需要分块后比 nn 小。

def encrypt_file(n, e):
    with open("input.txt", "rb") as f:
        data = f.read()
    # chunking by 255 bytes
    chunks = [data[i : i + 255] for i in range(0, len(data), 255)]

    with open("output.txt", "wb") as f:
        for chunk in chunks:
            m = int.from_bytes(chunk, "little")
            c = rsa_encrypt(m, e, n).to_bytes(512, "little")
            f.write(c)


def decrypt_file(n, d):
    with open("output.txt", "rb") as f:
        data = f.read()

    chunks = [data[i : i + 512] for i in range(0, len(data), 512)]
  
    with open("output_de.txt", "wb") as f:
        for chunk in chunks[:-1]:
            c = int.from_bytes(chunk, "little")
            m = rsa_decrypt(c, d, n).to_bytes(255, "little")
            f.write(m)

        c = int.from_bytes(chunks[-1], "little")
        m = rsa_decrypt(c, d, n).to_bytes(255, "little")
        # trim trailing zeros
        while m[-1] == 0:
            m = m[:-1]
        f.write(m)
          
def test_file():
    n, e, d = rsa_gen(*get_pq())
    encrypt_file(n, e)
    decrypt_file(n, d)


上一篇
DecisionTree
下一篇
test