Prerequisite

Writeup

次のスクリプトとその実行結果が与えられる

#!/usr/bin/env python3

from Crypto.Util.number import *
import random
import os
from flag import flag

def urand(b):
    return int.from_bytes(os.urandom(b), byteorder='big')

class PRNG:
    def __init__(self):
        self.m1 = 2 ** 32 - 107
        self.m2 = 2 ** 32 - 5
        self.m3 = 2 ** 32 - 209
        self.M = 2 ** 64 - 59

        rnd = random.Random(b'rbtree')

        self.a1 = [rnd.getrandbits(20) for _ in range(3)]
        self.a2 = [rnd.getrandbits(20) for _ in range(3)]
        self.a3 = [rnd.getrandbits(20) for _ in range(3)]

        self.x = [urand(4) for _ in range(3)]
        self.y = [urand(4) for _ in range(3)]
        self.z = [urand(4) for _ in range(3)]

    def out(self):
        o = (2 * self.m1 * self.x[0] - self.m3 * self.y[0] - self.m2 * self.z[0]) % self.M

        self.x = self.x[1:] + [sum(x * y for x, y in zip(self.x, self.a1)) % self.m1]
        self.y = self.y[1:] + [sum(x * y for x, y in zip(self.y, self.a2)) % self.m2]
        self.z = self.z[1:] + [sum(x * y for x, y in zip(self.z, self.a3)) % self.m3]

        return o.to_bytes(8, byteorder='big')

if __name__ == "__main__":
    prng = PRNG()

    hint = b''
    for i in range(12):
        hint += prng.out()
    
    print(hint.hex())

    assert len(flag) % 8 == 0
    stream = b''
    for i in range(len(flag) // 8):
        stream += prng.out()
    
    out = bytes([x ^ y for x, y in zip(flag, stream)])
    print(out.hex())
    

RNGの出力が12個与えられ、それに続く出力からstreamが構成され、これとフラグの排他的論理和をとったものcが与えられる

RNGの構造は前回書いた記事で扱ったCMRGと似ている。状態$x_i, y_i, z_i$の更新式と出力$o_i$は次の通り

$$ \begin{aligned} x_{i+3} &\equiv a_{11}x_{i} + a_{12}x_{i+1} + a_{13}x_{i+2} \mod m_1 \cr y_{i+3} &\equiv a_{11}y_{i} + a_{12}y_{i+1} + a_{13}y_{i+2} \mod m_2 \cr z_{i+3} &\equiv a_{11}z_{i} + a_{12}z_{i+1} + a_{13}z_{i+2} \mod m_3 \cr o_i &\equiv 2m_1x_i - m_3y_i - m_2z_i \end{aligned} $$

前の記事同様に中国人剰余定理を使って1つの式にまとめる為に$A,B,C$を次のように定義する

$$ \begin{aligned} A \equiv a_{11} \mod m_1, &\ A \equiv a_{21} \mod m_2, \ A \equiv a_{31} \mod m_3 \cr B \equiv a_{12} \mod m_1, &\ B \equiv a_{22} \mod m_2, \ B \equiv a_{32} \mod m_3 \cr C \equiv a_{13} \mod m_1, &\ C \equiv a_{23} \mod m_2, \ C \equiv a_{33} \mod m_3 \cr \end{aligned} $$

これに対して$X_i \equiv x_i \mod m_1, \ X_i \equiv y_i \mod m_2, \ X_i \equiv z_i \mod m_3$となる$X_i$を定義すると次が成り立つ

$$ X_{i+3} \equiv AX_i + BX_{i+1} + CX_{i+2} \mod m_1m_2m_3 $$

また、$X_i$の定義からある整数$k_i, k_i', k_i''$が存在して次が成り立つ

$$ \begin{aligned} X_i = k_im_1 + x_i \cr X_i = k_i'm_2+y_i \cr X_i = k_i''m_3+z_i \end{aligned} $$

ここで$o_i':=2m_1x_i - m_3y_i - m_2z_i$と定義する(mod無しの等号)。$o_i \equiv o_i' \mod M$であるが、$o_i-2M \leq o_i' \leq o_i+2M$であるので、$o_i'$は$o_i' = o_i + kN$の形で総当り出来る(適当な実験で$k =-2, -1, 0, 1$であることがわかった)。

また、$(2m_1 - m_3 - m_2)X_i = 2k_im_1^2 - k_i'm_2m_3 - k_i''m_2m_3 + o_i'$であり、$m_2m_3$で法をとることで次が成り立つ

$$ (2m_1 - m_3 - m_2)X_i \equiv 2k_im_1^2 + o_i' \mod m_2m_3 $$

更に今回の問題のパラメータは$2m_1 - m_3 - m_2 = 0$が成り立っているので$k_i$について解くと次のようになる

$$ k_i \equiv -\frac{o_i'}{2m_i^2} \mod m_2m_3 $$

$X_i = k_im_1 +x_i$を$X_{i+3} \equiv AX_i + BX_{i+1} + CX_{i+2} \mod m_1m_2m_3$に代入すると、未知数は$x_{i+3}, x_{i+2}, x_{i+1}, x_i$であるが、式のだいたいの項が96bitなのに対してこれらは32bitであり小さいので基底簡約で出てくる短いベクトルに現れる可能性がある。

ここで$x_4, x_5$に関して次が成り立つ

$$ \begin{aligned} x_4 &\equiv Ax_1 + Bx_2 + Cx_3 + D_4 \mod m_1m_2m_3 \cr X_5 &\equiv ACx_1 + (A+CB)x_2 + (B+C^2)x_3 + CD_4 + D_5 \mod m_1m_2m_3 \cr \end{aligned} $$

$D_i$は$X_i \equiv AX_{i-3} + Bx_{i-2} + Cx_{i-1} \mod m_1m_2m_3$に$X_j = k_jm_1 + x_j$を代入した時の定数項である。また、$x_5$に関しては$x_4$が現れるので、$x_4$に関する式を代入した。この式から法を外すと整数$l_4,l_5$を用いて次のようになる

$$ \begin{aligned} x_4 &= Ax_1 + Bx_2 + Cx_3 + D_4 + l_4m_1m_2m_3 \cr X_5 &= ACx_1 + (A+CB)x_2 + (B+C^2)x_3 + CD_4 + D_5 +l_5m_1m_2m_3 \cr \end{aligned} $$

この式を利用して次のような格子を組んだ

$$ \begin{pmatrix} 1 & & & & A & AC \cr & 1 & & & B & A+CB \cr & & 1 & & C & B+C^2 \cr & & & 2^{32} & D_4 & CD_4 + D_5 \cr & & & & m_1m_2m_3 & \cr & & & & & m_1m_2m_3 \end{pmatrix} $$

この格子に左から係数ベクトル$(x_1, x_2, x_3, 1, l_4, l_5)$を掛けると$(x_1, x_2, x_3, 2^{32}, x_4, x_5)$が現れる。LLLを使って簡約すると、格子の体積がだいたい$2^{32+96\times2} = 2^{224}$なので、出てくるベクトルの大きさはだいたい$2^{224/6} \approx 2^{38}$以下になってこのベクトルが出てきてくれる可能性がある。

以上より、$o_i$から$o_i'$を総当りし(これは$4^5$通りしか無いので可能)、そこから$k_i$を計算(つまり$D_i$も計算される)して上記のような格子を組んでから簡約すればどこかで$x_1, x_2, x_3$が入手出来る。

そこから$X_i = x_i + m_1k_i$を計算して$X_i$が手に入るので$m_2,m_3$で割ることで$y_i, z_i$も手に入り、RNGの出力を完全に予測出来る。よってstreamを再現してcと排他的論理和をとってフラグが手に入る。

Code

デバッグの跡で余分なコードが大量に残っている

#!/usr/bin/env python3

import itertools
import random
import os

def urand(b):
    return int.from_bytes(os.urandom(b), byteorder='big')

class PRNG:
    def __init__(self):
        self.m1 = 2 ** 32 - 107
        self.m2 = 2 ** 32 - 5
        self.m3 = 2 ** 32 - 209
        self.M = 2 ** 64 - 59
        self.m23 = self.m2 * self.m3
        self.m123 = self.m1 * self.m2 * self.m3
        self.u = power_mod(2 * self.m1**2, -1, self.m23)

        rnd = random.Random(b'rbtree')

        self.a1 = [rnd.getrandbits(20) for _ in range(3)]
        self.a2 = [rnd.getrandbits(20) for _ in range(3)]
        self.a3 = [rnd.getrandbits(20) for _ in range(3)]

        self.A=44560127569626536334684692547
        self.B=54178077656689068068903612461
        self.C=2714806752854611792965139512

        self.x = [urand(4) for _ in range(3)]
        self.y = [urand(4) for _ in range(3)]
        self.z = [urand(4) for _ in range(3)]

        self._x = []
        self.k = []
        self._o = []
        self.D = []


    def out(self):
        _o = 2 * self.m1 * self.x[0] - self.m3 * self.y[0] - self.m2 * self.z[0]
        self._o.append(_o)
        k = -self.u * _o % self.m23
        self.k.append(k)
        o = _o % self.M

        self._x.append(self.x[0])

        self.x = self.x[1:] + [sum(x * y for x, y in zip(self.x, self.a1)) % self.m1]
        self.y = self.y[1:] + [sum(x * y for x, y in zip(self.y, self.a2)) % self.m2]
        self.z = self.z[1:] + [sum(x * y for x, y in zip(self.z, self.a3)) % self.m3]


        return _o


    def check(self):
        if len(self.k) > 3:
            A, B, C = self.A, self.B, self.C
            m1 = self.m1
            ki3, ki2, ki1, ki = self.k[-4:]
            xi3, xi2, xi1, xi = self._x[-4:]
            D = -ki * m1 + A*ki3*m1 + B*ki2*m1 + C*ki1*m1
            self.D.append(D)
            rhs = A*xi3 + B*xi2 + C*xi1 + D
            # print(xi, rhs % self.m123)
            if len(self.k) == 4:
                rhs = A * self._x[0] + B * self._x[1] + C * self._x[2] + D
                # print(xi, rhs % self.m123)
            if len(self.k) == 5:
                D5 = self.D[-1]
                D4 = self.D[-2]
                rhs = A * C * self._x[0]
                rhs += (A + C*B) * self._x[1]
                rhs += (B+C**2) * self._x[2]
                rhs += C*D4 + D5
                # print(xi, rhs % self.m123)


def main():
    hint = "67f19d3da8af1480f39ac04f7e9134b2dc4ad094475b696224389c9ef29b8a2aff8933bd3fefa6e0d03827ab2816ba0fd9c0e2d73e01aa6f184acd9c58122616f9621fb8313a62efb27fb3d3aa385b89435630d0704f0dceec00fef703d54fca"
    c = "153ed807c00d585860b843a03871b11f60baf11fe72d2619283ec5b4d931435ac378e21abe67c47f7923fcde101f4f0c65b5ee48950820f9b26e33acf57868d5f0cbc2377a39a81918f8c20f61c71047c8e82b1c965fa01b58ad0569ce7521c7"

    hint = bytes.fromhex(hint)
    c = bytes.fromhex(c)

    orig_outs = [int.from_bytes(hint[8*i:8*i+8], "big") for i in range(len(hint) // 8)]

    print(orig_outs)

    M = 2 ** 64 - 59

    for qs in itertools.product([1, 0, -1, -2], repeat=5):
        outs = [o + q*M for o, q in zip(orig_outs, qs)]
        new_prng = solve(outs)

        if new_prng is not None:
            break

    stream = b''
    for i in range(len(c) // 8):
        stream += (int(new_prng.out() % M)).to_bytes(8, "big")
    
    out = bytes([x ^^ y for x, y in zip(c, stream)])
    print(out)

def solve(outs):
    prng = PRNG()

    a1 = prng.a1
    a2 = prng.a2
    a3 = prng.a3

    m1 = prng.m1
    m2 = prng.m2
    m3 = prng.m3
    m23 = m2 * m3
    m123 = m1 * m2 * m3
    M = prng.M

    u = power_mod(2 * m1**2, -1, m23)

    A=44560127569626536334684692547
    B=54178077656689068068903612461
    C=2714806752854611792965139512

    ks = []
    Ds = []

    for o in outs:
        k = -u * o % m23
        ks.append(k)

        if len(ks) > 3:
            ki3, ki2, ki1, ki = ks[-4:]
            D = -ki * m1 + A*ki3*m1 + B*ki2*m1 + C*ki1*m1
            D %= m123
            Ds.append(D)

    D4, D5 = Ds[0:2]

    size = 6
    mat = [
        [0 for _ in range(size)] for _ in range(size)
    ]

    mat[0][0] = 1
    mat[1][1] = 1
    mat[2][2] = 1
    mat[3][3] = 2**32
    mat[4][4] = m123
    mat[5][5] = m123

    mat[0][4] = A
    mat[1][4] = B
    mat[2][4] = C
    mat[3][4] = D4

    mat[0][5] = A*C % m123
    mat[1][5] = A + C*B % m123
    mat[2][5] = B+C**2 % m123
    mat[3][5] = C*D4 + D5 % m123

    mat = matrix(ZZ, mat)

    for b in mat.LLL():
        if abs(b[3]) == 2**32:
            x1, x2, x3 = list(map(abs, b[:3]))
            X1 = x1 + ks[0] * m1
            X2 = x2 + ks[1] * m1
            X3 = x3 + ks[2] * m1

            y1, y2, y3 = X1 % m2, X2 % m2, X3 % m2
            z1, z2, z3 = X1 % m3, X2 % m3, X3 % m3

            new_prng = PRNG()
            new_prng.x = [x1, x2, x3]
            new_prng.y = [y1, y2, y3]
            new_prng.z = [z1, z2, z3]

            _outs = [new_prng.out() for _ in range(12)]
            if _outs[:5] == outs:
                print(b)

                return new_prng

    return None


if __name__ == "__main__":
    main()

Flag

pbctf{Wow_how_did_you_solve_this?_I_thought_this_is_super_secure._Thank_you_for_solving_this!!!}

References