TL;DR

  • 入力$c$に対して、$e=65537$におけるRSAの秘密鍵$d$を指定したインデックス$i$で区切って前後入れ替えたもの($d_i$)を用いて$c^{d_i} \mod N$を計算した結果をくれる
  • RSAの問題だが、$N$が不明なのでRSAが乗法的であることを利用し、$N$の倍数を複数導出してからGCDをとって$N$を復元する
  • $d, d_{i+1}$の関係を利用すると、総当たりによって$d_i$の先頭の1桁が求まるのでこれを全ての$i$で行うことで$d$を復元する
  • フラグを暗号化している$e$は65537で無いので求めた$d$を利用して$\phi$の候補を絞り、フラグを暗号化している公開鍵に対する秘密鍵を求めて復号する

Prerequisite

  • RSA

Writeup

次のようなスクリプトが動いている

#!/usr/bin/python3
from sys import stdin, stdout, exit
from flag import FLAG
from secrets import randbelow
from gmpy import next_prime

p = int(next_prime(randbelow(2**512)))
q = int(next_prime(randbelow(2**512)))
n = p * q
e = 65537

phi = (p - 1)*(q - 1)
d = int(pow(e, -1, phi))
d_len = len(str(d))

print("encrypted flag", pow(FLAG, 3331646268016923629, n))
stdout.flush()

ctr = 0
def oracle(c, i):
    global ctr
    if ctr > 10 * d_len // 9:
        print("Come on, that was already way too generous...")
        return
    ctr += 1
    rotor = lambda d, i: int(str(d)[i % d_len:] + str(d)[:i % d_len])
    return int(pow(c, rotor(d, i), n))

banner = lambda: stdout.write("""
Pelle's Rotor Supported Arithmetic Oracle
1) Query the oracle with a ciphertext and rotation value.
2) Exit.
""")

banner()
stdout.flush()

choices = {
    1: oracle,
    2: exit
}

while True:
    try:
        choice = stdin.readline()
        print("c:")
        stdout.flush()
        cipher = stdin.readline()
        print("rot:")
        stdout.flush()
        rotation = stdin.readline()
        print(choices.get(int(choice))(int(cipher), int(rotation)))
        stdout.flush()
    except Exception as e:
        stdout.write("%s\n" % e)
        stdout.flush()
        exit()

最初に$e=3331646268016923629$を用いてRSAで暗号化したフラグをくれる。以後は$e=65537$の場合の秘密鍵$d$に対して、d_len = len(str(d))で定義されたd_lenを用いて10 * d_len // 9回まで次のようなオラクルに問い合わせることが出来る。

  • 入力: $c, i$
  • 10進法表記で$d$の上位$i$桁を取り出し、元の$d$の下の桁とした結果を返す
    • 該当箇所: rotor = lambda d, i: int(str(d)[i % d_len:] + str(d)[:i % d_len])

また、通常のRSA問題とは違って、公開鍵$N$が与えられていない。というわけで最初に$N$を特定する事を目指す。

これはオラクルの入力$c$に自由な入力をいれることが出来るので$c_1 = c_2c_3$となるような$c_1, c_2, c_3$を同じ$i$で入力するとrotor(d,i)を$d_i$とおいて、$c_1^{d_i} \equiv c_2^{d_i}c_3^{d_i} \mod N$となるから、$c_1^{d_i} - c_2^{d_i}c_3^{d_i} =k_i N$のように左辺が$N$の倍数となる。よって、異なる$c_1, c_2, c_3$の組み合わせを2つ用意し$c_1^{d_i} - c_2^{d_i}c_3^{d_i}$をそれぞれ計算して最大公約数を計算することで$N$を得ることが期待出来る。

$N$が得られので、次は$d$を得る事を考える。$d_i, d_{i+1}$には次のような関係がある。

$$ \begin{aligned} d_i &= x \cdot 10^l + y \cr d_{i+1} &= 10y + x \end{aligned} $$

ここで、$x$は10未満の非負整数とする。また、$d$の桁数は$N$と同じかやや小さいだけなので$l$は$N$の桁数 - 1になるとする(そうでなかったらそうなるまでスクリプトを回せば良い)。

また、$x,y$は共に未知数だが、$x$が高々10通りしかありえないのに対し、$y$の取りうる範囲は非常に広いことから、$y$を消去して$x$を総当りで特定する事を考える。これによって、次のようになる。

$$ \begin{aligned} d_{i+1} &= 10(d_i - x\cdot 10^l) + x \cr &= 10d_i - x(10^{l+1} - 1) \end{aligned} $$

よって両辺を$c$の指数とすると次のようになる。

$$ c^{d_{i+1}} \equiv \frac{c^{10d_i}}{c^{x(10^{l+1}-1)}} \mod N $$

これを更に変形すると次のようになる。但し、$r_i :\equiv c^{d_i} \mod N$とおいた。

$$ c^{x(10^{l+1}-1)} \equiv \frac{r^{10}}{r_{i+1}} \mod N $$

右辺はオラクルの結果から計算することが出来、左辺は$x$を総当りすることで候補を列挙出来る。よって、これを比べて一致した際に$x$を特定することが出来る。これを全ての$i$で行うことで$d$を復元出来る。

ここまでくれば$ed -1$が$\phi(N)$の倍数となり、$ed -1 = k\phi(N)$となる$k$は$e$と同程度($d$が$\phi(N)$と同程度なので)になるから、$k$を総当りすることで$\phi(N)$を求めることが出来る。これに対してフラグの復号を行って、フラグフォーマットに従っているものがフラグとなる。

Code

from math import gcd
from pwn import remote
import sys
from Crypto.Util.number import long_to_bytes

def oracle(c, i):
    sc.sendline(b"1")
    sc.recvuntil(b"c:\n")
    sc.sendline(str(c).encode())
    sc.recvuntil(b"rot:\n")
    sc.sendline(str(i).encode())
    res = int(sc.recvline())

    return res


args = sys.argv
DEBUG = len(args) > 1 and args[1] == "-d"

sc = remote("localhost", 13337)
sc.recvuntil(b"encrypted flag ")
ct = int(sc.recvline())

# for debug
if DEBUG:
    sc.recvuntil(b"n=")
    true_n = int(sc.recvline())
    sc.recvuntil(b"d=")
    true_d = int(sc.recvline())

sc.recvuntil(b"2) Exit.\n")

cs = [
    [2,3,6],
    [7,13,91]
]
kN = []
for i in range(2):
    r1 = oracle(cs[i][0], 0)
    r2 = oracle(cs[i][1], 0)
    r3 = oracle(cs[i][2], 0)
    kN.append(r1*r2 - r3)

N = gcd(kN[0], kN[1])
N_length = len(str(N))
coef = 10**N_length - 1
c = 2

d_digits = []
r_previous = None
r = None
for i in range(len(str(N))+1):
    r = oracle(c, i)
    if i == 0:
        r_previous = r
        continue

    found = False
    rhs = pow(r_previous, 10, N) * pow(r, -1, N) % N
    for x in range(10):
        lhs = pow(c, x * coef, N)
        if rhs == lhs:
            found = True
            print(f"[{i}]: {x=}")
            break

    if not found:
        print("ha?")
        exit()

    d_digits.append(x)
    r_previous = r

e = 65537
d = int("".join(map(str, d_digits)))
if DEBUG:
    assert d == true_d

k_phi = e*d - 1

phi_cands = []
for k in range(1, e):
    if k_phi % k == 0:
        phi = k_phi // k
        if phi.bit_length() == N.bit_length():
            phi_cands.append(phi)

print(len(phi_cands))
_e = 3331646268016923629
for phi in phi_cands:
    _d = pow(_e, -1, phi)

    pt = pow(ct, _d, N)
    print(long_to_bytes(pt))

Flag

(ローカルでやっただけなので無し)