TL;DR

  • RSAの復号オラクルが動いているが、秘密鍵の全てのbitが毎度異なる値でマスクされている
  • 1bitのXORは四則演算に直して計算すると元の秘密鍵のbitを変数とした式が立つ
  • 指数法則をいい感じに使うと行列の問題に落ちるので基底を入手し、ある1つのbit以外0であるような係数を用意すると、それを指数として利用することで、対応する秘密鍵のbitが0なら1が得られるような式になる
  • これを判定条件として秘密鍵の各bitを特定して秘密鍵を復元する
  • 以降、どのような指数が暗号文に掛けられたか分かるので互いに素なものを用意すれば拡張ユークリッド互除法を用いて暗号文の指数を1にして暗号文を得て、秘密鍵を指数としてべき乗すれば平文が手に入る

Prerequisite

  • 1bit排他的論理和の四則演算表現
  • 指数法則
  • 線形代数

Writeup

次のようなスクリプトが動いている。

from secrets import randbits
from Crypto.Util.number import getPrime  # pycryptodome

NBITS = 1024
D_NBITS = 128  # small `d` makes decryption faster


class Cipher:
    def __init__(self):
        p = getPrime(NBITS // 2)
        q = getPrime(NBITS // 2)
        self.n = p * q
        self.d = getPrime(D_NBITS)
        self.e = pow(self.d, -1, (p - 1) * (q - 1))

    def encrypt(self, m: int) -> int:
        assert m < self.n
        return pow(m, self.e, self.n)

    def faultily_decrypt(self, c: int):
        assert c < self.n
        fault_vector = randbits(D_NBITS)
        return fault_vector, pow(c, self.d ^ fault_vector, self.n)


def main():
    from secret import FLAG
    cipher = Cipher()
    c = cipher.encrypt(int.from_bytes(FLAG.encode(), "big"))

    for _ in range(2022):
        line = input()
        print(cipher.faultily_decrypt(c if line == 'c' else int(line)))


if __name__ == '__main__':
    main()

RSAの復号オラクルが動いている。但し、128bitの秘密鍵$d$に対して、同じく128bitの数値$v$でマスクした値を指数としている。具体的にはオラクルに送信した暗号文$c$に対して$c^{d\oplus v} \mod n$と$v$が返される。

ところでRSAが関連する問題なのに$n$も$e$も与えられない。これでは何も歯が立たない予感がするのでひとまず$n$を特定する。

オラクルに$-1$を与えると$-1^{d\oplus v} \mod n$が返されるが、もし$d \oplus v$が奇数ならこれは$-1 \equiv n-1 \mod n$になり、一方偶数なら1が返される。よって前者の場合を引いたらそれに1を足すことで$n$を取得出来る。

オラクルからは2つの値が得られるが、復号出来ていない平文の方を$m$とおく。ここで、$d$と$v$のibit目を$b_i, v_i$とおくと次のようになる。

$$ m \equiv c^{\sum_{i=0}^{127}({b_i \oplus v_i})2^i} \equiv \prod_{i=0}^{127} c^{(b_i\oplus v_i)2^i} \equiv \prod_{i=0}^{127} {c_i'}^{(b_i \oplus v_i)} \mod n $$

ここで$c_i' :\equiv c^{2^i} \mod n$とおいた。

ところで、1bitの排他的論理和$\oplus$について、次のような関係がある。

$$ b_1 \oplus b_2 = b_1+b_2 - 2b_1b_2 $$

これを代入してあげると先程の式は次のようになる。

$$ m \equiv \prod_{i=0}^{127} {c_i'}^{b_i(1-2v_i)}{c_i'}^{v_i} \mod n $$

$v_i$は既知なので${c_i'}^{v_i}$は移項して左辺に持っていくことが出来る。また、$b_i(1-2v_i)$は$b_i$または$-b_i$になり、これもどちらになるかは$v_i$が既知なので判明する。よって、$b_i^{s_i}$と$s_i =\pm 1$を用いて表すことにする。これで次のような式が得られる。

$$ m' \coloneqq m \prod_{i=0}^{127} {c_i'}^{-v_i} \equiv \prod_{i=0}^{127} ({c_i'}^{b_i})^{s_i} \mod n $$

復号オラクルには何回も問い合わせられるので、常に同じ$c$で何度も問い合わせて得られた$j$回目の結果から計算された$m'$を$m'_j$とおく。これに対して指数法則を用いると、$j$回目のオラクルで計算された$s_i$を$s_i^{(j)}$として次が成り立つ。

$$ \prod_j {m_j'}^{x_j} \equiv \prod_{i=0}^{127} ({c_i'}^{b_i})^{\sum_{j} x_js_i^{(j)}} \mod n $$

右辺の${c_i'}^{b_i}$の指数に関して線形性が成り立つことから、左辺の積は右辺の指数の和に、左辺のべき乗は右辺の指数のスカラー倍に対応する。よって各$j$に対して$s_i$を並べたベクトル$(s_0^{(j)}, \dots, s_{127}^{(j)})$を考え、適切に基底を選んで、スカラー倍と加算を繰り返して指数$\sum_{j} x_js_i^{(j)}$を生成すれば、$s_i = \pm 1$にかぎらず、任意の$s_i$に対して$\prod_{i=0}^{127}({c_i'}^{b_i})^{s_i}$を計算できる予感がする。基底とするには128個の一次独立なベクトルで十分である。

正確には$(s_0, s_1, s_2, \dots, s_{127}) = (1,0,0,\dots,0)$のようなものを生成するのは難しい、というのも係数$(x_0,x_1, x_2, \dots, x_{127})$として有理数を許すなら可能だが、整数だけならそれは格子となり、$\mathbb Z^{128}$の部分集合ではあるが$\mathbb Z^{128}$に一致するような格子を得ることは難しい。$\mathbb Z/n\mathbb Z$上では整数乗は可能だが、有理数乗は難しいため、係数を求めたとしてもそれを左辺の$m'_j$の指数とすることは出来ない。

それでも$(1,0,0,\dots,0)$となる有理数係数$x_0, x_1, x_2, \dots, x_{127}$が得られたのなら、任意の成分の分母を除去する数(任意成分の分母の最小公倍数)を掛けた整数係数$ax_0, ax_1, ax_2, \dots, ax_{127}$によって$(a, 0, 0,\dots, 0)$が生成される。この例では右辺は$(c_0'^{b_0})^{a} \equiv {c_0'}^{ab_0} \mod n$となるが、$b_i$は0か1であり、$b_i=0$ならこの値は1になるはずである。

よって、このような1成分(そこの添字を$i$とする)だけ0以外の整数となるような$s_0, s_1, \dots, s_{127}$を生成する整数係数$x_0, x_1, \dots, x_{127}$を選んで、左辺を計算し、その値が1かどうかを判定することで$b_i$を特定出来る。これを全ての$i$に対して行えば$d$を復元出来る。

$d$を入手出来たので、以後オラクルから$m$を得た時、$m\equiv c^{d\oplus v} \mod n$における$d \oplus v$を特定出来る。

ここで$d_i = d \oplus v_i$とおくと、もし$d_i, d_j$が互いに素なら$sd_i+td_j = 1$となる整数$s,t$を求めることが出来る。よって、$m_i^sm_j^t \equiv c^{sd_i+td_j} \equiv c \mod n$となるので、これを$d$乗すれば平文が手に入る。

オラクルでは文字列"c"を送信することで、フラグを平文とした暗号文(未知)が送信されるので、その結果でこの方法を用いればフラグが手に入る。

Code

from pwn import process
import sys
from Crypto.Util.number import long_to_bytes


def f_decrypt(c=None):
    if c is not None:
        sc.sendline(str(c).encode())
    else:
        sc.sendline("c".encode())

    v, _m = eval(sc.recvline())

    return v, _m


def v_to_svec(v):
    ret = []
    for _ in range(128):
        if v & 1:
            ret.append(-1)
        else:
            ret.append(1)
        v >>= 1

    return vector(ZZ, ret)


DEBUG = False
if len(sys.argv) > 1 and sys.argv[1] == "-d":
    DEBUG = True

if DEBUG:
    sc = process(["python3", "test_fault.py"])
else:
    sc = process(["python3", "fault.py"])


if DEBUG:
    sc.recvuntil(b"n = ")
    _n = int(sc.recvline())
    sc.recvuntil(b"d = ")
    _d = int(sc.recvline())


while True:
    v, _m = f_decrypt(-1)

    if _m != 1:
        n = _m + 1
        break


if DEBUG:
    assert n == _n

c = 2

_cs = []
for i in range(128):
    _cs.append(power_mod(c, 2^i, n))

_cs = vector(ZZ, _cs)

M = []
_ms = []
for i in range(128):
    v, _m = f_decrypt(c)
    s = v_to_svec(v)
    M.append(s)

    for i in range(128):
        if s[i] == -1:
            _m *= power_mod(_cs[i], -1, n)
            _m %= n

    _ms.append(_m)

print("[+] decryptions are completed")
M = matrix(ZZ, M)

b_d = ""
for i in range(128):
    target = [0 for j in range(i)] + [1] + [0 for j in range(128 - i - 1)]
    target = vector(ZZ, target)
    x = M.solve_left(target)

    denoms = set()
    for y in x:
        denoms.add(y.denom())

    max_denom = max(denoms)

    x = max_denom * x

    c_d = 1
    for j in range(128):
        c_d *= power_mod(_ms[j], int(x[j]), n)
        c_d %= n

    print(i, c_d)
    if c_d != 1:
        b_d = "1" + b_d
    else:
        b_d = "0" + b_d

d = int(b_d, 2)
if DEBUG:
    assert d == _d

print(f"[+] d is recovered: {d}")

_ds = []
_ms = []
while True:
    v, _m = f_decrypt()
    _d = v ^^ d
    _ds.append(_d)
    _ms.append(_m)
    for i, other_d in enumerate(_ds):
        g, s, t = xgcd(other_d, _d)
        if g == 1:
            c = power_mod(_ms[i], s, n) * power_mod(_m, t, n) % n
            print(long_to_bytes(power_mod(c, d, n)))
            exit()

Flag

ローカルでやっただけ

Resources