TetCTF 2022 - fault
TL;DR
- RSAの復号オラクルが動いているが、秘密鍵の全てのbitが毎度異なる値でマスクされている
- 1bitのXORは四則演算に直して計算すると元の秘密鍵のbitを変数とした式が立つ
- 指数法則をいい感じに使うと行列の問題に落ちるので基底を入手し、ある1つのbit以外0であるような係数を用意すると、それを指数として利用することで、対応する秘密鍵のbitが0なら1が得られるような式になる
- これを判定条件として秘密鍵の各bitを特定して秘密鍵を復元する
- 以降、どのような指数が暗号文に掛けられたか分かるので互いに素なものを用意すれば拡張ユークリッド互除法を用いて暗号文の指数を1にして暗号文を得て、秘密鍵を指数としてべき乗すれば平文が手に入る
Prerequisite
- 1bit排他的論理和の四則演算表現
- 指数法則
- 線形代数
Writeup
次のようなスクリプトが動いている。
from secrets import randbits
from Crypto.Util.number import getPrime # pycryptodome
NBITS = 1024
D_NBITS = 128 # small `d` makes decryption faster
class Cipher:
def __init__(self):
p = getPrime(NBITS // 2)
q = getPrime(NBITS // 2)
self.n = p * q
self.d = getPrime(D_NBITS)
self.e = pow(self.d, -1, (p - 1) * (q - 1))
def encrypt(self, m: int) -> int:
assert m < self.n
return pow(m, self.e, self.n)
def faultily_decrypt(self, c: int):
assert c < self.n
fault_vector = randbits(D_NBITS)
return fault_vector, pow(c, self.d ^ fault_vector, self.n)
def main():
from secret import FLAG
cipher = Cipher()
c = cipher.encrypt(int.from_bytes(FLAG.encode(), "big"))
for _ in range(2022):
line = input()
print(cipher.faultily_decrypt(c if line == 'c' else int(line)))
if __name__ == '__main__':
main()
RSAの復号オラクルが動いている。但し、128bitの秘密鍵に対して、同じく128bitの数値でマスクした値を指数としている。具体的にはオラクルに送信した暗号文に対してとが返される。
ところでRSAが関連する問題なのにもも与えられない。これでは何も歯が立たない予感がするのでひとまずを特定する。
オラクルにを与えるとが返されるが、もしが奇数ならこれはになり、一方偶数なら1が返される。よって前者の場合を引いたらそれに1を足すことでを取得出来る。
オラクルからは2つの値が得られるが、復号出来ていない平文の方をとおく。ここで、とのibit目をとおくと次のようになる。
ここでとおいた。
ところで、1bitの排他的論理和について、次のような関係がある。
これを代入してあげると先程の式は次のようになる。
は既知なのでは移項して左辺に持っていくことが出来る。また、はまたはになり、これもどちらになるかはが既知なので判明する。よって、とを用いて表すことにする。これで次のような式が得られる。
復号オラクルには何回も問い合わせられるので、常に同じで何度も問い合わせて得られた回目の結果から計算されたをとおく。これに対して指数法則を用いると、回目のオラクルで計算されたをとして次が成り立つ。
右辺のの指数に関して線形性が成り立つことから、左辺の積は右辺の指数の和に、左辺のべき乗は右辺の指数のスカラー倍に対応する。よって各に対してを並べたベクトルを考え、適切に基底を選んで、スカラー倍と加算を繰り返して指数を生成すれば、にかぎらず、任意のに対してを計算できる予感がする。基底とするには128個の一次独立なベクトルで十分である。
正確にはのようなものを生成するのは難しい、というのも係数として有理数を許すなら可能だが、整数だけならそれは格子となり、の部分集合ではあるがに一致するような格子を得ることは難しい。上では整数乗は可能だが、有理数乗は難しいため、係数を求めたとしてもそれを左辺のの指数とすることは出来ない。
それでもとなる有理数係数が得られたのなら、任意の成分の分母を除去する数(任意成分の分母の最小公倍数)を掛けた整数係数によってが生成される。この例では右辺はとなるが、は0か1であり、ならこの値は1になるはずである。
よって、このような1成分(そこの添字をとする)だけ0以外の整数となるようなを生成する整数係数を選んで、左辺を計算し、その値が1かどうかを判定することでを特定出来る。これを全てのに対して行えばを復元出来る。
を入手出来たので、以後オラクルからを得た時、におけるを特定出来る。
ここでとおくと、もしが互いに素ならとなる整数を求めることが出来る。よって、となるので、これを乗すれば平文が手に入る。
オラクルでは文字列"c"を送信することで、フラグを平文とした暗号文(未知)が送信されるので、その結果でこの方法を用いればフラグが手に入る。
Code
from pwn import process
import sys
from Crypto.Util.number import long_to_bytes
def f_decrypt(c=None):
if c is not None:
sc.sendline(str(c).encode())
else:
sc.sendline("c".encode())
v, _m = eval(sc.recvline())
return v, _m
def v_to_svec(v):
ret = []
for _ in range(128):
if v & 1:
ret.append(-1)
else:
ret.append(1)
v >>= 1
return vector(ZZ, ret)
DEBUG = False
if len(sys.argv) > 1 and sys.argv[1] == "-d":
DEBUG = True
if DEBUG:
sc = process(["python3", "test_fault.py"])
else:
sc = process(["python3", "fault.py"])
if DEBUG:
sc.recvuntil(b"n = ")
_n = int(sc.recvline())
sc.recvuntil(b"d = ")
_d = int(sc.recvline())
while True:
v, _m = f_decrypt(-1)
if _m != 1:
n = _m + 1
break
if DEBUG:
assert n == _n
c = 2
_cs = []
for i in range(128):
_cs.append(power_mod(c, 2^i, n))
_cs = vector(ZZ, _cs)
M = []
_ms = []
for i in range(128):
v, _m = f_decrypt(c)
s = v_to_svec(v)
M.append(s)
for i in range(128):
if s[i] == -1:
_m *= power_mod(_cs[i], -1, n)
_m %= n
_ms.append(_m)
print("[+] decryptions are completed")
M = matrix(ZZ, M)
b_d = ""
for i in range(128):
target = [0 for j in range(i)] + [1] + [0 for j in range(128 - i - 1)]
target = vector(ZZ, target)
x = M.solve_left(target)
denoms = set()
for y in x:
denoms.add(y.denom())
max_denom = max(denoms)
x = max_denom * x
c_d = 1
for j in range(128):
c_d *= power_mod(_ms[j], int(x[j]), n)
c_d %= n
print(i, c_d)
if c_d != 1:
b_d = "1" + b_d
else:
b_d = "0" + b_d
d = int(b_d, 2)
if DEBUG:
assert d == _d
print(f"[+] d is recovered: {d}")
_ds = []
_ms = []
while True:
v, _m = f_decrypt()
_d = v ^^ d
_ds.append(_d)
_ms.append(_m)
for i, other_d in enumerate(_ds):
g, s, t = xgcd(other_d, _d)
if g == 1:
c = power_mod(_ms[i], s, n) * power_mod(_m, t, n) % n
print(long_to_bytes(power_mod(c, d, n)))
exit()
Flag
ローカルでやっただけ