Prerequisite

  • Coppersmith's Attack
  • MT19937の予測

Writeup

次のようなスクリプトが動いている

from random import getrandbits
from Crypto.Util.number import getPrime, long_to_bytes, bytes_to_long

def keygen(): # normal rsa key generation
  primes = []
  e = 3
  for _ in range(2):
    while True:
      p = getPrime(1024)
      if (p - 1) % 3:
        break
    primes.append(p)
  return e, primes[0] * primes[1]

def pad(m, n): # pkcs#1 v1.5
  ms = long_to_bytes(m)
  ns = long_to_bytes(n)
  if len(ms) >= len(ns) - 11:
    return -1
  padlength = len(ns) - len(ms) - 3
  ps = long_to_bytes(getrandbits(padlength * 8)).rjust(padlength, b"\x00")
  return int.from_bytes(b"\x00\x02" + ps + b"\x00" + ms, "big")

def encrypt(m, e, n): # standard rsa
  res = pad(m, n)
  if res != -1:
    print(f"c: {pow(res, e, n)}")
  else:
    print("error :(", "message too long")

menu = """
[1] enc()
[2] enc(flag)
[3] quit
"""[1:]

e, n = keygen()
print(f"e: {e}")
print(f"n: {n}")
assert len(open("/challenge/flag.txt", "rb").read()) < 55

while True:
  try:
    print(menu)
    opt = input("opt: ")
    if opt == "1":
      encrypt(int(input("msg: ")), e, n)
    elif opt == "2":
      encrypt(bytes_to_long(open("/challenge/flag.txt", "rb").read()), e, n)
    elif opt == "3":
      print("bye")
      exit(0)
    else:
      print("idk")
  except Exception as e:
    print("error :(", e)

PKCS#1 v1.5で平文にパディングを施しているRSAが暗号化に用いられ次のコマンドが実行出来る

  1. enc(): (長さに制限はあるが)任意の平文の暗号文を取得
  2. enc(flag): フラグの暗号文を取得
  3. quit: 終了

$e=3$と小さいが、パディングが厄介なせいでHastad Broadcast Attackや単純なCoppersmith's Attackは行えない。一見無敵なように見えるが、よく見るとパディングの乱数部分の生成にはgetrandbits()を利用しているため、パディングを十分な数だけ特定出来れば、以降のパディングを完全に特定出来てCoppersmith's Attackが使えそうな予感がする。

(※正直に白状するとmystizさんのCrypto in CTF: Q3 2021 :: Mystifyを眺めてこの問題を見つけた時にmt19937のタグを見てしまったので、randomモジュールを利用するんだろうと当たりを付けていた)

ここで、平文の暗号化コマンドで暗号化する平文はこちらから指定可能、つまり既知であるため、結果からパディングを未知数とする3次の合同方程式が得られる。パディングの長さは、平文の長さを調整することで自由に設定可能なため、Coppersmith's Attackで求められる程度にすればパディングを特定出来る。

$n$は2048bitであるから、680bitぐらいまでのパディングは求めることが出来る。しかし、計算時間が長くなるので直ぐ終える事が出来る256bitにパディングを設定する。これはMT19937の乱数生成8回分に相当することから78回暗号化を行ってその度にCoppersmith's Attackでパディングを求めれば624個の乱数が手に入り、以降の乱数を完全に予測可能になる。

これでフラグの暗号化を行うと、ここで使われたパディングが特定になるため、今度は逆に平文が未知数の合同方程式が立つ。フラグの長さは高々440bitで$N$の大きさは2048btなので、Coppersmith's Attackで求めることが出来る。

なお、フラグの正確な長さがわからないとパディングがどうなるかはわからないが、フラグの長さとしてありえるものが55未満なのでこれは総当り可能。

Code

from pwn import remote
from Crypto.Util.number import long_to_bytes, bytes_to_long
import random


def choice(c):
    sc.recvuntil(b"opt: ")
    sc.sendline(str(c).encode())


def enc(m):
    choice(1)
    sc.recvuntil(b"msg: ")
    sc.sendline(str(m).encode())
    sc.recvuntil(b"c: ")
    return int(sc.recvline())


def make_poly(m, c):
    l_m = len(long_to_bytes(m))
    l_n = len(long_to_bytes(n))
    padlength = l_n - l_m - 3
    f = (2**((l_n - 2) * 8 + 1) + x_pad * (2**((l_m + 1) * 8)) + m)^3 - c
    return f.monic()


def pad_split(pad):
    ret = []
    for _ in range(8):
        ret.append(pad & 0xffffffff)
        pad >>= 32

    return ret


def untemper(x):
    x ^^= (x >> 18)
    x ^^= ((x << 15) & 0xefc60000)
    x_bottom_14 = (x ^^ (x << 7) & 0x9d2c5680) # & ((1 << 14) - 1)
    x_bottom_21 = (x ^^ (x_bottom_14 << 7) & 0x9d2c5680) # & ((1 << 21) - 1)
    x_bottom_28 = (x ^^ (x_bottom_21 << 7) & 0x9d2c5680) # & ((1 << 28) - 1)
    x ^^= (x_bottom_28 << 7) & 0x9d2c5680
    x_top_22 = x ^^ (x >> 11)
    x ^^= (x_top_22 >> 11)

    return int(x)


e = 3

sc = remote("localhost", 13337)
sc.recvuntil(b"n: ")
n = int(sc.recvline())
assert len(long_to_bytes(n)) == 256
print(n.bit_length())
PR.<x_pad> = PolynomialRing(Zmod(n))

state = []
for i in range(78):
    while True:
        m = randint(1, 2^1768)
        if len(long_to_bytes(m)) == 221:
            break

    c = enc(m)
    f = make_poly(m, c)
    rs = f.small_roots()
    if len(rs) != 1:
        print("???")
        exit()

    pad = int(rs[0])
    print(i, pad.bit_length())

    state = state + list(map(untemper, pad_split(pad)))

mt_state = tuple(state[:624] + [624])
mt_state = (3, mt_state, None)
random.setstate(mt_state)

# check
predict_pad = random.getrandbits(256)

while True:
    m = randint(1, 2^1768)
    if len(long_to_bytes(m)) == 221:
        break

c = enc(m)
f = make_poly(m, c)
pad = f.small_roots()[0]
assert int(pad) == predict_pad

state = random.getstate()
choice(2)
sc.recvuntil(b"c: ")
c = int(sc.recvline())

for l_m in range(54, 0, -1):
    random.setstate(state)
    l_n = len(long_to_bytes(n))
    padlength = l_n - l_m - 3
    pad = b"\x00\x02" + long_to_bytes(random.getrandbits(padlength * 8)).rjust(padlength, b"\x00") + b"\x00"
    f = (bytes_to_long(pad) * (2**(l_m * 8)) + x_pad)^3 - c

    rs = f.small_roots()
    if len(rs) > 0:
        print(long_to_bytes(rs[0]))

Flag

rarctf{but-th3y_t0ld_m3_th1s_p4dd1ng_w45-s3cur3!!}

Resources