Midnight Sun CTF Quals 2022 - Pelle's Rotor-Supported Arithmetic
TL;DR
- 入力$c$に対して、$e=65537$におけるRSAの秘密鍵$d$を指定したインデックス$i$で区切って前後入れ替えたもの($d_i$)を用いて$c^{d_i} \mod N$を計算した結果をくれる
- RSAの問題だが、$N$が不明なのでRSAが乗法的であることを利用し、$N$の倍数を複数導出してからGCDをとって$N$を復元する
- $d, d_{i+1}$の関係を利用すると、総当たりによって$d_i$の先頭の1桁が求まるのでこれを全ての$i$で行うことで$d$を復元する
- フラグを暗号化している$e$は65537で無いので求めた$d$を利用して$\phi$の候補を絞り、フラグを暗号化している公開鍵に対する秘密鍵を求めて復号する
Prerequisite
- RSA
Writeup
次のようなスクリプトが動いている
#!/usr/bin/python3
from sys import stdin, stdout, exit
from flag import FLAG
from secrets import randbelow
from gmpy import next_prime
p = int(next_prime(randbelow(2**512)))
q = int(next_prime(randbelow(2**512)))
n = p * q
e = 65537
phi = (p - 1)*(q - 1)
d = int(pow(e, -1, phi))
d_len = len(str(d))
print("encrypted flag", pow(FLAG, 3331646268016923629, n))
stdout.flush()
ctr = 0
def oracle(c, i):
global ctr
if ctr > 10 * d_len // 9:
print("Come on, that was already way too generous...")
return
ctr += 1
rotor = lambda d, i: int(str(d)[i % d_len:] + str(d)[:i % d_len])
return int(pow(c, rotor(d, i), n))
banner = lambda: stdout.write("""
Pelle's Rotor Supported Arithmetic Oracle
1) Query the oracle with a ciphertext and rotation value.
2) Exit.
""")
banner()
stdout.flush()
choices = {
1: oracle,
2: exit
}
while True:
try:
choice = stdin.readline()
print("c:")
stdout.flush()
cipher = stdin.readline()
print("rot:")
stdout.flush()
rotation = stdin.readline()
print(choices.get(int(choice))(int(cipher), int(rotation)))
stdout.flush()
except Exception as e:
stdout.write("%s\n" % e)
stdout.flush()
exit()
最初に$e=3331646268016923629$を用いてRSAで暗号化したフラグをくれる。以後は$e=65537$の場合の秘密鍵$d$に対して、d_len = len(str(d))
で定義されたd_len
を用いて10 * d_len // 9
回まで次のようなオラクルに問い合わせることが出来る。
- 入力: $c, i$
- 10進法表記で$d$の上位$i$桁を取り出し、元の$d$の下の桁とした結果を返す
- 該当箇所:
rotor = lambda d, i: int(str(d)[i % d_len:] + str(d)[:i % d_len])
- 該当箇所:
また、通常のRSA問題とは違って、公開鍵$N$が与えられていない。というわけで最初に$N$を特定する事を目指す。
これはオラクルの入力$c$に自由な入力をいれることが出来るので$c_1 = c_2c_3$となるような$c_1, c_2, c_3$を同じ$i$で入力するとrotor(d,i)
を$d_i$とおいて、$c_1^{d_i} \equiv c_2^{d_i}c_3^{d_i} \mod N$となるから、$c_1^{d_i} - c_2^{d_i}c_3^{d_i} =k_i N$のように左辺が$N$の倍数となる。よって、異なる$c_1, c_2, c_3$の組み合わせを2つ用意し$c_1^{d_i} - c_2^{d_i}c_3^{d_i}$をそれぞれ計算して最大公約数を計算することで$N$を得ることが期待出来る。
$N$が得られので、次は$d$を得る事を考える。$d_i, d_{i+1}$には次のような関係がある。
$$ \begin{aligned} d_i &= x \cdot 10^l + y \cr d_{i+1} &= 10y + x \end{aligned} $$
ここで、$x$は10未満の非負整数とする。また、$d$の桁数は$N$と同じかやや小さいだけなので$l$は$N$の桁数 - 1になるとする(そうでなかったらそうなるまでスクリプトを回せば良い)。
また、$x,y$は共に未知数だが、$x$が高々10通りしかありえないのに対し、$y$の取りうる範囲は非常に広いことから、$y$を消去して$x$を総当りで特定する事を考える。これによって、次のようになる。
$$ \begin{aligned} d_{i+1} &= 10(d_i - x\cdot 10^l) + x \cr &= 10d_i - x(10^{l+1} - 1) \end{aligned} $$
よって両辺を$c$の指数とすると次のようになる。
$$ c^{d_{i+1}} \equiv \frac{c^{10d_i}}{c^{x(10^{l+1}-1)}} \mod N $$
これを更に変形すると次のようになる。但し、$r_i :\equiv c^{d_i} \mod N$とおいた。
$$ c^{x(10^{l+1}-1)} \equiv \frac{r^{10}}{r_{i+1}} \mod N $$
右辺はオラクルの結果から計算することが出来、左辺は$x$を総当りすることで候補を列挙出来る。よって、これを比べて一致した際に$x$を特定することが出来る。これを全ての$i$で行うことで$d$を復元出来る。
ここまでくれば$ed -1$が$\phi(N)$の倍数となり、$ed -1 = k\phi(N)$となる$k$は$e$と同程度($d$が$\phi(N)$と同程度なので)になるから、$k$を総当りすることで$\phi(N)$を求めることが出来る。これに対してフラグの復号を行って、フラグフォーマットに従っているものがフラグとなる。
Code
from math import gcd
from pwn import remote
import sys
from Crypto.Util.number import long_to_bytes
def oracle(c, i):
sc.sendline(b"1")
sc.recvuntil(b"c:\n")
sc.sendline(str(c).encode())
sc.recvuntil(b"rot:\n")
sc.sendline(str(i).encode())
res = int(sc.recvline())
return res
args = sys.argv
DEBUG = len(args) > 1 and args[1] == "-d"
sc = remote("localhost", 13337)
sc.recvuntil(b"encrypted flag ")
ct = int(sc.recvline())
# for debug
if DEBUG:
sc.recvuntil(b"n=")
true_n = int(sc.recvline())
sc.recvuntil(b"d=")
true_d = int(sc.recvline())
sc.recvuntil(b"2) Exit.\n")
cs = [
[2,3,6],
[7,13,91]
]
kN = []
for i in range(2):
r1 = oracle(cs[i][0], 0)
r2 = oracle(cs[i][1], 0)
r3 = oracle(cs[i][2], 0)
kN.append(r1*r2 - r3)
N = gcd(kN[0], kN[1])
N_length = len(str(N))
coef = 10**N_length - 1
c = 2
d_digits = []
r_previous = None
r = None
for i in range(len(str(N))+1):
r = oracle(c, i)
if i == 0:
r_previous = r
continue
found = False
rhs = pow(r_previous, 10, N) * pow(r, -1, N) % N
for x in range(10):
lhs = pow(c, x * coef, N)
if rhs == lhs:
found = True
print(f"[{i}]: {x=}")
break
if not found:
print("ha?")
exit()
d_digits.append(x)
r_previous = r
e = 65537
d = int("".join(map(str, d_digits)))
if DEBUG:
assert d == true_d
k_phi = e*d - 1
phi_cands = []
for k in range(1, e):
if k_phi % k == 0:
phi = k_phi // k
if phi.bit_length() == N.bit_length():
phi_cands.append(phi)
print(len(phi_cands))
_e = 3331646268016923629
for phi in phi_cands:
_d = pow(_e, -1, phi)
pt = pow(ct, _d, N)
print(long_to_bytes(pt))
Flag
(ローカルでやっただけなので無し)