TL;DR

  • 入力ccに対して、e=65537e=65537におけるRSAの秘密鍵ddを指定したインデックスiiで区切って前後入れ替えたもの(did_i)を用いてcdimod  Nc^{d_i} \mod Nを計算した結果をくれる
  • RSAの問題だが、NNが不明なのでRSAが乗法的であることを利用し、NNの倍数を複数導出してからGCDをとってNNを復元する
  • d,di+1d, d_{i+1}の関係を利用すると、総当たりによってdid_iの先頭の1桁が求まるのでこれを全てのiiで行うことでddを復元する
  • フラグを暗号化しているeeは65537で無いので求めたddを利用してϕ\phiの候補を絞り、フラグを暗号化している公開鍵に対する秘密鍵を求めて復号する

Prerequisite

  • RSA

Writeup

次のようなスクリプトが動いている

#!/usr/bin/python3
from sys import stdin, stdout, exit
from flag import FLAG
from secrets import randbelow
from gmpy import next_prime

p = int(next_prime(randbelow(2**512)))
q = int(next_prime(randbelow(2**512)))
n = p * q
e = 65537

phi = (p - 1)*(q - 1)
d = int(pow(e, -1, phi))
d_len = len(str(d))

print("encrypted flag", pow(FLAG, 3331646268016923629, n))
stdout.flush()

ctr = 0
def oracle(c, i):
    global ctr
    if ctr > 10 * d_len // 9:
        print("Come on, that was already way too generous...")
        return
    ctr += 1
    rotor = lambda d, i: int(str(d)[i % d_len:] + str(d)[:i % d_len])
    return int(pow(c, rotor(d, i), n))

banner = lambda: stdout.write("""
Pelle's Rotor Supported Arithmetic Oracle
1) Query the oracle with a ciphertext and rotation value.
2) Exit.
""")

banner()
stdout.flush()

choices = {
    1: oracle,
    2: exit
}

while True:
    try:
        choice = stdin.readline()
        print("c:")
        stdout.flush()
        cipher = stdin.readline()
        print("rot:")
        stdout.flush()
        rotation = stdin.readline()
        print(choices.get(int(choice))(int(cipher), int(rotation)))
        stdout.flush()
    except Exception as e:
        stdout.write("%s\n" % e)
        stdout.flush()
        exit()

最初にe=3331646268016923629e=3331646268016923629を用いてRSAで暗号化したフラグをくれる。以後はe=65537e=65537の場合の秘密鍵ddに対して、d_len = len(str(d))で定義されたd_lenを用いて10 * d_len // 9回まで次のようなオラクルに問い合わせることが出来る。

  • 入力: c,ic, i
  • 10進法表記でddの上位ii桁を取り出し、元のddの下の桁とした結果を返す
    • 該当箇所: rotor = lambda d, i: int(str(d)[i % d_len:] + str(d)[:i % d_len])

また、通常のRSA問題とは違って、公開鍵NNが与えられていない。というわけで最初にNNを特定する事を目指す。

これはオラクルの入力ccに自由な入力をいれることが出来るのでc1=c2c3c_1 = c_2c_3となるようなc1,c2,c3c_1, c_2, c_3を同じiiで入力するとrotor(d,i)did_iとおいて、c1dic2dic3dimod  Nc_1^{d_i} \equiv c_2^{d_i}c_3^{d_i} \mod Nとなるから、c1dic2dic3di=kiNc_1^{d_i} - c_2^{d_i}c_3^{d_i} =k_i Nのように左辺がNNの倍数となる。よって、異なるc1,c2,c3c_1, c_2, c_3の組み合わせを2つ用意しc1dic2dic3dic_1^{d_i} - c_2^{d_i}c_3^{d_i}をそれぞれ計算して最大公約数を計算することでNNを得ることが期待出来る。

NNが得られので、次はddを得る事を考える。di,di+1d_i, d_{i+1}には次のような関係がある。

di=x10l+ydi+1=10y+x \begin{aligned} d_i &= x \cdot 10^l + y \cr d_{i+1} &= 10y + x \end{aligned}

ここで、xxは10未満の非負整数とする。また、ddの桁数はNNと同じかやや小さいだけなのでllNNの桁数 - 1になるとする(そうでなかったらそうなるまでスクリプトを回せば良い)。

また、x,yx,yは共に未知数だが、xxが高々10通りしかありえないのに対し、yyの取りうる範囲は非常に広いことから、yyを消去してxxを総当りで特定する事を考える。これによって、次のようになる。

di+1=10(dix10l)+x=10dix(10l+11) \begin{aligned} d_{i+1} &= 10(d_i - x\cdot 10^l) + x \cr &= 10d_i - x(10^{l+1} - 1) \end{aligned}

よって両辺をccの指数とすると次のようになる。

cdi+1c10dicx(10l+11)mod  N c^{d_{i+1}} \equiv \frac{c^{10d_i}}{c^{x(10^{l+1}-1)}} \mod N

これを更に変形すると次のようになる。但し、ri:cdimod  Nr_i :\equiv c^{d_i} \mod Nとおいた。

cx(10l+11)r10ri+1mod  N c^{x(10^{l+1}-1)} \equiv \frac{r^{10}}{r_{i+1}} \mod N

右辺はオラクルの結果から計算することが出来、左辺はxxを総当りすることで候補を列挙出来る。よって、これを比べて一致した際にxxを特定することが出来る。これを全てのiiで行うことでddを復元出来る。

ここまでくればed1ed -1ϕ(N)\phi(N)の倍数となり、ed1=kϕ(N)ed -1 = k\phi(N)となるkkeeと同程度(ddϕ(N)\phi(N)と同程度なので)になるから、kkを総当りすることでϕ(N)\phi(N)を求めることが出来る。これに対してフラグの復号を行って、フラグフォーマットに従っているものがフラグとなる。

Code

from math import gcd
from pwn import remote
import sys
from Crypto.Util.number import long_to_bytes

def oracle(c, i):
    sc.sendline(b"1")
    sc.recvuntil(b"c:\n")
    sc.sendline(str(c).encode())
    sc.recvuntil(b"rot:\n")
    sc.sendline(str(i).encode())
    res = int(sc.recvline())

    return res


args = sys.argv
DEBUG = len(args) > 1 and args[1] == "-d"

sc = remote("localhost", 13337)
sc.recvuntil(b"encrypted flag ")
ct = int(sc.recvline())

# for debug
if DEBUG:
    sc.recvuntil(b"n=")
    true_n = int(sc.recvline())
    sc.recvuntil(b"d=")
    true_d = int(sc.recvline())

sc.recvuntil(b"2) Exit.\n")

cs = [
    [2,3,6],
    [7,13,91]
]
kN = []
for i in range(2):
    r1 = oracle(cs[i][0], 0)
    r2 = oracle(cs[i][1], 0)
    r3 = oracle(cs[i][2], 0)
    kN.append(r1*r2 - r3)

N = gcd(kN[0], kN[1])
N_length = len(str(N))
coef = 10**N_length - 1
c = 2

d_digits = []
r_previous = None
r = None
for i in range(len(str(N))+1):
    r = oracle(c, i)
    if i == 0:
        r_previous = r
        continue

    found = False
    rhs = pow(r_previous, 10, N) * pow(r, -1, N) % N
    for x in range(10):
        lhs = pow(c, x * coef, N)
        if rhs == lhs:
            found = True
            print(f"[{i}]: {x=}")
            break

    if not found:
        print("ha?")
        exit()

    d_digits.append(x)
    r_previous = r

e = 65537
d = int("".join(map(str, d_digits)))
if DEBUG:
    assert d == true_d

k_phi = e*d - 1

phi_cands = []
for k in range(1, e):
    if k_phi % k == 0:
        phi = k_phi // k
        if phi.bit_length() == N.bit_length():
            phi_cands.append(phi)

print(len(phi_cands))
_e = 3331646268016923629
for phi in phi_cands:
    _d = pow(_e, -1, phi)

    pt = pow(ct, _d, N)
    print(long_to_bytes(pt))

Flag

(ローカルでやっただけなので無し)