pbctf 2021 - Yet Another PRNG
Prerequisite
- CMRG
- 一昨日書いたこの記事を前提にして書くので読んでください
Writeup
次のスクリプトとその実行結果が与えられる
#!/usr/bin/env python3
from Crypto.Util.number import *
import random
import os
from flag import flag
def urand(b):
return int.from_bytes(os.urandom(b), byteorder='big')
class PRNG:
def __init__(self):
self.m1 = 2 ** 32 - 107
self.m2 = 2 ** 32 - 5
self.m3 = 2 ** 32 - 209
self.M = 2 ** 64 - 59
rnd = random.Random(b'rbtree')
self.a1 = [rnd.getrandbits(20) for _ in range(3)]
self.a2 = [rnd.getrandbits(20) for _ in range(3)]
self.a3 = [rnd.getrandbits(20) for _ in range(3)]
self.x = [urand(4) for _ in range(3)]
self.y = [urand(4) for _ in range(3)]
self.z = [urand(4) for _ in range(3)]
def out(self):
o = (2 * self.m1 * self.x[0] - self.m3 * self.y[0] - self.m2 * self.z[0]) % self.M
self.x = self.x[1:] + [sum(x * y for x, y in zip(self.x, self.a1)) % self.m1]
self.y = self.y[1:] + [sum(x * y for x, y in zip(self.y, self.a2)) % self.m2]
self.z = self.z[1:] + [sum(x * y for x, y in zip(self.z, self.a3)) % self.m3]
return o.to_bytes(8, byteorder='big')
if __name__ == "__main__":
prng = PRNG()
hint = b''
for i in range(12):
hint += prng.out()
print(hint.hex())
assert len(flag) % 8 == 0
stream = b''
for i in range(len(flag) // 8):
stream += prng.out()
out = bytes([x ^ y for x, y in zip(flag, stream)])
print(out.hex())
RNGの出力が12個与えられ、それに続く出力からstream
が構成され、これとフラグの排他的論理和をとったものc
が与えられる
RNGの構造は前回書いた記事で扱ったCMRGと似ている。状態$x_i, y_i, z_i$の更新式と出力$o_i$は次の通り
$$ \begin{aligned} x_{i+3} &\equiv a_{11}x_{i} + a_{12}x_{i+1} + a_{13}x_{i+2} \mod m_1 \cr y_{i+3} &\equiv a_{11}y_{i} + a_{12}y_{i+1} + a_{13}y_{i+2} \mod m_2 \cr z_{i+3} &\equiv a_{11}z_{i} + a_{12}z_{i+1} + a_{13}z_{i+2} \mod m_3 \cr o_i &\equiv 2m_1x_i - m_3y_i - m_2z_i \end{aligned} $$
前の記事同様に中国人剰余定理を使って1つの式にまとめる為に$A,B,C$を次のように定義する
$$ \begin{aligned} A \equiv a_{11} \mod m_1, &\ A \equiv a_{21} \mod m_2, \ A \equiv a_{31} \mod m_3 \cr B \equiv a_{12} \mod m_1, &\ B \equiv a_{22} \mod m_2, \ B \equiv a_{32} \mod m_3 \cr C \equiv a_{13} \mod m_1, &\ C \equiv a_{23} \mod m_2, \ C \equiv a_{33} \mod m_3 \cr \end{aligned} $$
これに対して$X_i \equiv x_i \mod m_1, \ X_i \equiv y_i \mod m_2, \ X_i \equiv z_i \mod m_3$となる$X_i$を定義すると次が成り立つ
$$ X_{i+3} \equiv AX_i + BX_{i+1} + CX_{i+2} \mod m_1m_2m_3 $$
また、$X_i$の定義からある整数$k_i, k_i', k_i''$が存在して次が成り立つ
$$ \begin{aligned} X_i = k_im_1 + x_i \cr X_i = k_i'm_2+y_i \cr X_i = k_i''m_3+z_i \end{aligned} $$
ここで$o_i':=2m_1x_i - m_3y_i - m_2z_i$と定義する(mod無しの等号)。$o_i \equiv o_i' \mod M$であるが、$o_i-2M \leq o_i' \leq o_i+2M$であるので、$o_i'$は$o_i' = o_i + kN$の形で総当り出来る(適当な実験で$k =-2, -1, 0, 1$であることがわかった)。
また、$(2m_1 - m_3 - m_2)X_i = 2k_im_1^2 - k_i'm_2m_3 - k_i''m_2m_3 + o_i'$であり、$m_2m_3$で法をとることで次が成り立つ
$$ (2m_1 - m_3 - m_2)X_i \equiv 2k_im_1^2 + o_i' \mod m_2m_3 $$
更に今回の問題のパラメータは$2m_1 - m_3 - m_2 = 0$が成り立っているので$k_i$について解くと次のようになる
$$ k_i \equiv -\frac{o_i'}{2m_i^2} \mod m_2m_3 $$
$X_i = k_im_1 +x_i$を$X_{i+3} \equiv AX_i + BX_{i+1} + CX_{i+2} \mod m_1m_2m_3$に代入すると、未知数は$x_{i+3}, x_{i+2}, x_{i+1}, x_i$であるが、式のだいたいの項が96bitなのに対してこれらは32bitであり小さいので基底簡約で出てくる短いベクトルに現れる可能性がある。
ここで$x_4, x_5$に関して次が成り立つ
$$ \begin{aligned} x_4 &\equiv Ax_1 + Bx_2 + Cx_3 + D_4 \mod m_1m_2m_3 \cr X_5 &\equiv ACx_1 + (A+CB)x_2 + (B+C^2)x_3 + CD_4 + D_5 \mod m_1m_2m_3 \cr \end{aligned} $$
$D_i$は$X_i \equiv AX_{i-3} + Bx_{i-2} + Cx_{i-1} \mod m_1m_2m_3$に$X_j = k_jm_1 + x_j$を代入した時の定数項である。また、$x_5$に関しては$x_4$が現れるので、$x_4$に関する式を代入した。この式から法を外すと整数$l_4,l_5$を用いて次のようになる
$$ \begin{aligned} x_4 &= Ax_1 + Bx_2 + Cx_3 + D_4 + l_4m_1m_2m_3 \cr X_5 &= ACx_1 + (A+CB)x_2 + (B+C^2)x_3 + CD_4 + D_5 +l_5m_1m_2m_3 \cr \end{aligned} $$
この式を利用して次のような格子を組んだ
$$ \begin{pmatrix} 1 & & & & A & AC \cr & 1 & & & B & A+CB \cr & & 1 & & C & B+C^2 \cr & & & 2^{32} & D_4 & CD_4 + D_5 \cr & & & & m_1m_2m_3 & \cr & & & & & m_1m_2m_3 \end{pmatrix} $$
この格子に左から係数ベクトル$(x_1, x_2, x_3, 1, l_4, l_5)$を掛けると$(x_1, x_2, x_3, 2^{32}, x_4, x_5)$が現れる。LLLを使って簡約すると、格子の体積がだいたい$2^{32+96\times2} = 2^{224}$なので、出てくるベクトルの大きさはだいたい$2^{224/6} \approx 2^{38}$以下になってこのベクトルが出てきてくれる可能性がある。
以上より、$o_i$から$o_i'$を総当りし(これは$4^5$通りしか無いので可能)、そこから$k_i$を計算(つまり$D_i$も計算される)して上記のような格子を組んでから簡約すればどこかで$x_1, x_2, x_3$が入手出来る。
そこから$X_i = x_i + m_1k_i$を計算して$X_i$が手に入るので$m_2,m_3$で割ることで$y_i, z_i$も手に入り、RNGの出力を完全に予測出来る。よってstream
を再現してc
と排他的論理和をとってフラグが手に入る。
Code
デバッグの跡で余分なコードが大量に残っている
#!/usr/bin/env python3
import itertools
import random
import os
def urand(b):
return int.from_bytes(os.urandom(b), byteorder='big')
class PRNG:
def __init__(self):
self.m1 = 2 ** 32 - 107
self.m2 = 2 ** 32 - 5
self.m3 = 2 ** 32 - 209
self.M = 2 ** 64 - 59
self.m23 = self.m2 * self.m3
self.m123 = self.m1 * self.m2 * self.m3
self.u = power_mod(2 * self.m1**2, -1, self.m23)
rnd = random.Random(b'rbtree')
self.a1 = [rnd.getrandbits(20) for _ in range(3)]
self.a2 = [rnd.getrandbits(20) for _ in range(3)]
self.a3 = [rnd.getrandbits(20) for _ in range(3)]
self.A=44560127569626536334684692547
self.B=54178077656689068068903612461
self.C=2714806752854611792965139512
self.x = [urand(4) for _ in range(3)]
self.y = [urand(4) for _ in range(3)]
self.z = [urand(4) for _ in range(3)]
self._x = []
self.k = []
self._o = []
self.D = []
def out(self):
_o = 2 * self.m1 * self.x[0] - self.m3 * self.y[0] - self.m2 * self.z[0]
self._o.append(_o)
k = -self.u * _o % self.m23
self.k.append(k)
o = _o % self.M
self._x.append(self.x[0])
self.x = self.x[1:] + [sum(x * y for x, y in zip(self.x, self.a1)) % self.m1]
self.y = self.y[1:] + [sum(x * y for x, y in zip(self.y, self.a2)) % self.m2]
self.z = self.z[1:] + [sum(x * y for x, y in zip(self.z, self.a3)) % self.m3]
return _o
def check(self):
if len(self.k) > 3:
A, B, C = self.A, self.B, self.C
m1 = self.m1
ki3, ki2, ki1, ki = self.k[-4:]
xi3, xi2, xi1, xi = self._x[-4:]
D = -ki * m1 + A*ki3*m1 + B*ki2*m1 + C*ki1*m1
self.D.append(D)
rhs = A*xi3 + B*xi2 + C*xi1 + D
# print(xi, rhs % self.m123)
if len(self.k) == 4:
rhs = A * self._x[0] + B * self._x[1] + C * self._x[2] + D
# print(xi, rhs % self.m123)
if len(self.k) == 5:
D5 = self.D[-1]
D4 = self.D[-2]
rhs = A * C * self._x[0]
rhs += (A + C*B) * self._x[1]
rhs += (B+C**2) * self._x[2]
rhs += C*D4 + D5
# print(xi, rhs % self.m123)
def main():
hint = "67f19d3da8af1480f39ac04f7e9134b2dc4ad094475b696224389c9ef29b8a2aff8933bd3fefa6e0d03827ab2816ba0fd9c0e2d73e01aa6f184acd9c58122616f9621fb8313a62efb27fb3d3aa385b89435630d0704f0dceec00fef703d54fca"
c = "153ed807c00d585860b843a03871b11f60baf11fe72d2619283ec5b4d931435ac378e21abe67c47f7923fcde101f4f0c65b5ee48950820f9b26e33acf57868d5f0cbc2377a39a81918f8c20f61c71047c8e82b1c965fa01b58ad0569ce7521c7"
hint = bytes.fromhex(hint)
c = bytes.fromhex(c)
orig_outs = [int.from_bytes(hint[8*i:8*i+8], "big") for i in range(len(hint) // 8)]
print(orig_outs)
M = 2 ** 64 - 59
for qs in itertools.product([1, 0, -1, -2], repeat=5):
outs = [o + q*M for o, q in zip(orig_outs, qs)]
new_prng = solve(outs)
if new_prng is not None:
break
stream = b''
for i in range(len(c) // 8):
stream += (int(new_prng.out() % M)).to_bytes(8, "big")
out = bytes([x ^^ y for x, y in zip(c, stream)])
print(out)
def solve(outs):
prng = PRNG()
a1 = prng.a1
a2 = prng.a2
a3 = prng.a3
m1 = prng.m1
m2 = prng.m2
m3 = prng.m3
m23 = m2 * m3
m123 = m1 * m2 * m3
M = prng.M
u = power_mod(2 * m1**2, -1, m23)
A=44560127569626536334684692547
B=54178077656689068068903612461
C=2714806752854611792965139512
ks = []
Ds = []
for o in outs:
k = -u * o % m23
ks.append(k)
if len(ks) > 3:
ki3, ki2, ki1, ki = ks[-4:]
D = -ki * m1 + A*ki3*m1 + B*ki2*m1 + C*ki1*m1
D %= m123
Ds.append(D)
D4, D5 = Ds[0:2]
size = 6
mat = [
[0 for _ in range(size)] for _ in range(size)
]
mat[0][0] = 1
mat[1][1] = 1
mat[2][2] = 1
mat[3][3] = 2**32
mat[4][4] = m123
mat[5][5] = m123
mat[0][4] = A
mat[1][4] = B
mat[2][4] = C
mat[3][4] = D4
mat[0][5] = A*C % m123
mat[1][5] = A + C*B % m123
mat[2][5] = B+C**2 % m123
mat[3][5] = C*D4 + D5 % m123
mat = matrix(ZZ, mat)
for b in mat.LLL():
if abs(b[3]) == 2**32:
x1, x2, x3 = list(map(abs, b[:3]))
X1 = x1 + ks[0] * m1
X2 = x2 + ks[1] * m1
X3 = x3 + ks[2] * m1
y1, y2, y3 = X1 % m2, X2 % m2, X3 % m2
z1, z2, z3 = X1 % m3, X2 % m3, X3 % m3
new_prng = PRNG()
new_prng.x = [x1, x2, x3]
new_prng.y = [y1, y2, y3]
new_prng.z = [z1, z2, z3]
_outs = [new_prng.out() for _ in range(12)]
if _outs[:5] == outs:
print(b)
return new_prng
return None
if __name__ == "__main__":
main()
Flag
pbctf{Wow_how_did_you_solve_this?_I_thought_this_is_super_secure._Thank_you_for_solving_this!!!}