序文 §

いつものupsolveのネタが尽きてきたので重めの問題(pbctf 2021 - Yet Another PRNG)をやろうとしたら、参考となっている論文が面白かったので紹介するついでにPoCを書きます

CMRGと問題設定 §

今回解析するRNGはCMRG("Combined Multiple Recursive Generators")というやつで、論文中では次のように状態xi,yix_i,y_iの更新と出力ziz_iが定義されている。

xi=a11xi1+a12xi2+a13xi3mod  m1yi=a21yi1+a22yi2+a23yi3mod  m2zi=xiyimod  m1 \begin{aligned} x_i &= a_{11}x_{i-1}+a_{12}x_{i-2}+a{13}x_{i-3} \mod m_1 \cr y_i &= a_{21}y_{i-1}+a_{22}y_{i-2}+a{23}y_{i-3} \mod m_2 \cr z_i &= x_i - y_i \mod m_1 \end{aligned}

問題設定として、a11,a12,a13,a21,a22,a23,m1,m2a_{11}, a_{12}, a_{13}, a_{21}, a_{22}, a_{23}, m_1, m_2は既知であり、ここからある程度(今回の方法による最低値は7個)の出力ziz_iを得てから、後続の出力を予測する。

CRT §

上記の通り、CMRGは実質的に2つの線形なRNGを組み合わせている。もし、m1,m2m_1, m_2が互いに素(大抵は周期を伸ばすためにm1,m2m_1, m_2が素数が使われるはずなので満たされる)であれば、中国人剰余定理を使って次のようなA,B,CA,B,Cを求めて2つのRNGを1つのものとして扱う事が出来る

Aa11mod  m1, Aa21mod  m2Ba21mod  m1, Aa22mod  m2Ca31mod  m1, Aa23mod  m2 \begin{aligned} A \equiv a_{11} \mod m_1, \ A\equiv a_{21} \mod m_2 \cr B \equiv a_{21} \mod m_1, \ A\equiv a_{22} \mod m_2 \cr C \equiv a_{31} \mod m_1, \ A\equiv a_{23} \mod m_2 \cr \end{aligned}

このA,B,CA,B,Cに対して、Xiximod  m1,Xiyimod  m2X_i \equiv x_i \mod m_1, X_i \equiv y_i \mod m_2とすれば次が成り立つ。

XiAXi1+BXi2+CXi3mod  m1m2 X_i \equiv AX_{i-1} + BX_{i-2} + CX_{i-3} \mod m_1m_2

格子に落とす式 §

上記で定義したXiX_iとそれに関する式に対してXiX_ixix_iで表すことが出来れば、m1m2m_1m_2を法としている式に対して、xim1m2x_i \approx \sqrt{m_1m_2}が現れ、これは式中の項の中で比較的小さい値となるから格子の問題に落とすことが出来る予感がする。というわけで試みる。

ここでzixiyiz_i' \coloneqq x_i - y_iと定義する。通常の出力ziz_iとの違いはm1m_1で法を取っているかいないかである。

出力の式を見るとzi=xiyimod  m1z_i = x_i- y_i \mod m_1となっているが、m1m_1m2m_2が近いと仮定すれば、m1<xiyi<m1-m_1 < x_i - y_i < m_1が成り立つので、zi=xiyizi=xiyi+m1z_i = x_i - y_i \lor z_i = x_i - y_i + m_1としてよい。よって、ziz_i'は出力ziz_iに対して2択であり、実際得た出力に対して全探索することでどこかで正しいziz_i'に当たる。よって、以下ではziz_i'を引き当てた場合を仮定する。

XiX_iの定義から、Xi=kim1+xi=kim2+yiX_i = k_i m_1 + x_i = k_i'm_2 + y_iとなる整数0ki<m2,0ki<m10\leq k_i <m_2, 0\leq k_i' < m_1が存在する。中辺と右辺を変形すると次が成り立つ

zi=xiyi=kim2kim1 z_i' = x_i - y_i = k_i'm_2 - k_im_1

更にm2m_2で法をとると、zikim1mod  m2z_i' \equiv -k_im_1\mod m_2が成り立つから、kik_iについて解くとkizim1mod  m2k_i \equiv -\frac{z_i'}{m_1} \mod m_2となる。論文に従って、um11mod  m2u \coloneqq m_1^{-1} \mod m_2とおいて、kium1mod  m2k_i \equiv -um_1 \mod m_2とする。

よって、このようなkik_iを用いれば次が成り立つ

ki+3m1+xi+3A(ki+2m1+xi+2)B(ki+1m1+xi+1)C(kim1+xi)0mod  m1m2 k_{i+3}m_1+x_{i+3} - A(k_{i+2}m_1 + x_{i+2}) - B(k_{i+1}m_1 + x_{i+1}) - C(k_im_1 + x_i) \equiv 0 \mod m_1m_2

この式中でxi+3,xi+2,xi+1,xix_{i+3}, x_{i+2}, x_{i+1}, x_i以外の項の大きさはm1m2m_1m_2と同程度であるが、これらはm1m2\sqrt{m_1m_2}程度であるから有意に小さい。よってLLL等の基底簡約アルゴリズムで短いベクトルを求めればその中にxi+3,xi+2,xi+1,xix_{i+3}, x_{i+2}, x_{i+1}, x_iが現れるような格子を構成出来る可能性が見えてくる。

LLLで倒す §

初期状態をx3,x2,x1,y3,y2,y1x_{-3}, x_{-2}, x_{-1}, y_{-3}, y_{-2}, y_{-1}として7つの出力z0,z1,z2,z3,z4,z5,z6z_0, z_1, z_2, z_3, z_4, z_5, z_6を得る。簡単のためそれぞれのziz_iに対してziz_i'を当てられたとする(よってkik_iは全て計算できる)。

この時、先程の式からx3x_3x0,x1,x2x_0, x_1, x_2を変数として表すことが出来る。具体的には次のようになる。

x3A(k2m1+x2)+B(k1m1+x1)+C(kim1+x0)k3m1mod  m1m2Ax2+Bx1+Cx0+D3mod  m1m2 \begin{aligned} x_3 &\equiv A(k_{2}m_1 + x_{2}) + B(k_{1}m_1 + x_{1}) + C(k_im_1 + x_0) - k_3m_1 \mod m_1m_2\cr &\equiv Ax_2 + Bx_1 + Cx_0 + D_3 \mod m_1m_2 \end{aligned}

ここで、D3D_3は定数項を全部足したものとする。また、DiD_iに関しては次のような関係がある。

Dikim1+Aki1m1+Bki2m1+Cki3m1mod  m1m2 D_{i} \equiv -k_{i}m_1 + Ak_{i-1}m_1 + Bk_{i-2}m_1 + Ck_{i-3}m_1 \mod m_1m_2

同様にしてx4x_4x3,x2,x1x_3, x_2, x_1で表すことが出来るが、x3x_3x2,x1,x0x_2,x_1,x_0で表すことが出来たのでx4x_4も同様である。

このようにしてx3,x4,x5,x6x_3, x_4, x_5,x_6はいずれもx0,x1,x2x_0, x_1, x_2の線形多項式で表すことが出来るので次のようにおく。

xici,0x0+ci,1x1+ci,2x2+Dimod  m1m2 x_i \equiv c_{i,0}x_0 + c_{i,1}x_1 + c_{i,2}x_2 + D_i \mod m_1m_2

m1m2m_1m_2を外して、商をlil_iとおくと次のようになる

xi=ci,0x0+ci,1x1+ci,2x2+Di+lim1m2 x_i = c_{i,0}x_0 + c_{i,1}x_1 + c_{i,2}x_2 + D_i + l_im_1m_2

先に述べたように、xix_iが短いベクトルの成分として現れるような格子を組んで簡約し、これらを求めることを考える。今回は次のような格子を組んだ。

(1c3,0c4,0c5,0c6,01c3,1c4,1c5,1c6,11c3,2c4,2c5,2c6,2232c3,3c4,3c5,3c6,3m1m2m1m2m1m2m1m2) \begin{pmatrix} 1 & & & & c_{3,0} & c_{4,0} & c_{5,0} & c_{6,0} \cr & 1 & & & c_{3,1} & c_{4,1} & c_{5,1} & c_{6,1} \cr & & 1 & & c_{3,2} & c_{4,2} & c_{5,2} & c_{6,2} \cr & & & 2^{32} & c_{3,3} & c_{4,3} & c_{5,3} & c_{6,3} \cr & & & & m_1m_2 & \cr & & & & & m_1m_2 & \cr & & & & & & m_1m_2 & \cr & & & & & & & m_1m_2 \cr \end{pmatrix}

これに左から(x0,x1,x2,1,l3,l4,l5,l6)(x_0, x_1, x_2, 1, l_3, l_4, l_5, l_6)を掛けると、(x0,x1,x2,232,x3,x4,x5,x6)(x_0, x_1, x_2, 2^{32}, x_3, x_4, x_5, x_6)が現れる。

この格子の体積は232(m1m2)42^{32}(m_1m_2)^4であるので、LLLで出てくる基底の大きさは(だいたい)24(m1m2)12242322^4(m_1m_2)^{\frac 12} \approx 2^4\cdot2^{32}より小さくなり、(x0,x1,x2,232,x3,x4,x5,x6)(x_0, x_1, x_2, 2^{32}, x_3, x_4, x_5, x_6)のノルムがだいたいこのぐらいなので出てくれると期待出来る。

x0,x1,x2x_0, x_1, x_2が求められれば、z0,z1,z2z_0', z_1', z_2'からy0,y1,y2y_0, y_1, y_2を求めることが出来るので以降の出力を完全に予測出来る。

係数と定数項を求める §

xix_iに対して、x0,x1,x2x_0, x_1, x_2の係数ci,0,ci,1,ci,2c_{i,0}, c_{i,1}, c_{i,2}と定数項DiD_iを手計算で求めようとすると骨が折れすぎるので次のスクリプトで求めた。

A = var("A")
B = var("B")
C = var("C")

Ds = [var(f"D{i}") for i in range(4)]
vars = [var(f"v{i}") for i in range(3)]

for i in range(3, 3 + 4):
    Ds.append(var(f"D{i}"))
    vars.append(A*vars[-1] + B*vars[-2] + C*vars[-3] + Ds[-1])

for i, v in enumerate(vars):
    if i < 3:
        continue
    c_v0 = v.list(vars[0])[1]
    c_v1 = v.list(vars[1])[1]
    c_v2 = v.list(vars[2])[1]
    constant = v.list(vars[0])[0].list(vars[1])[0].list(vars[2])[0]
    coeffs_dump = f"""{i}:
    v0_c = {c_v0}
    v1_c = {c_v1}
    v2_c = {c_v2}
    constant = {constant}"""

    print(coeffs_dump)

これを実行すると次のような結果になり、各係数が得られる

3:
    v0_c = C
    v1_c = B
    v2_c = A
    constant = D3
4:
    v0_c = A*C
    v1_c = A*B + C
    v2_c = A^2 + B
    constant = A*D3 + D4
5:
    v0_c = A^2*C + B*C
    v1_c = A^2*B + B^2 + A*C
    v2_c = A^3 + 2*A*B + C
    constant = A^2*D3 + B*D3 + A*D4 + D5
6:
    v0_c = A^3*C + 2*A*B*C + C^2
    v1_c = A^3*B + 2*A*B^2 + A^2*C + 2*B*C
    v2_c = A^4 + 3*A^2*B + B^2 + 2*A*C
    constant = A^3*D3 + 2*A*B*D3 + A^2*D4 + C*D3 + B*D4 + A*D5 + D6

余談だが、参考にした論文(参考文献に記載)では、これと似たような格子を組んで簡約しているようだが、論文に掲載されている格子は計算ミスで値が誤っているようである。

Code §

PRNG内で幾つかのチェックをしており見にくいが、最終的に7つのziz_i'と既知の値のみからx0,x1,x2x_0, x_1, x_2を復元している

# based on https://eprint.iacr.org/2021/1204.pdf

import random


# ref and parameter stolen from: http://www.secmem.org/blog/2021/10/24/Breaking-Combined-Multiple-Recursive-Generators/
class PRNG:
    def __init__(self) -> None:
        self.m1 = 2**32 - 107
        self.m2 = 2**32 - 5
        self.N = self.m1 * self.m2

        assert is_prime(self.m1)
        assert is_prime(self.m2)

        self.a1 = [random.getrandbits(32) for _ in range(3)]
        self.a2 = [random.getrandbits(32) for _ in range(3)]

        self.__x1 = [random.getrandbits(32) for _ in range(3)]
        self.__x2 = [random.getrandbits(32) for _ in range(3)]

        # debug parameters

        # CRT
        self.A = crt([self.a1[0], self.a2[0]], [self.m1, self.m2])
        self.B = crt([self.a1[1], self.a2[1]], [self.m1, self.m2])
        self.C = crt([self.a1[2], self.a2[2]], [self.m1, self.m2])

        self.Ds = []

        self.__answer1 = []
        self.__answer2 = []

        self.u = power_mod(self.m1, -1, self.m2)

        # z' in paper
        # self.__z = [x - y for x,y in zip(self.__x1, self.__x2)]
        # self.__k = [-z * self.u % self.m2 for z in self.__z]
        self.__z = []
        self.__k = []


    def next(self, i=None) -> int:
        new_x1 = sum([x * y for x, y in zip(reversed(self.a1), self.__x1)]) % self.m1
        new_x2 = sum([x * y for x, y in zip(reversed(self.a2), self.__x2)]) % self.m2


        new_z = new_x1 - new_x2
        new_k = -new_z * self.u % self.m2
        # t1 = new_k * self.m1 + new_x1
        # t2 = self.A * (self.__k[2] * self.m1 + self.__x1[2])
        # t3 = self.B * (self.__k[1] * self.m1 + self.__x1[1])
        # t4 = self.C * (self.__k[0] * self.m1 + self.__x1[0])
        # P = (t1 - t2 - t3 - t4) % (self.m1 * self.m2)

        if i is not None and i > 2:
            D = (-new_k + self.A * self.__k[-1] + self.B * self.__k[-2] +self.C * self.__k[-3]) * self.m1 % (self.N)
            self.Ds.append(D)
            P = (self.A * self.__x1[2] + self.B * self.__x1[1] + self.C * self.__x1[0] + D) % (self.N)
            assert P == new_x1

        # rough check
        if i in [3, 4, 5, 6]:
            v0, v1, v2 = self.__answer1[:3]
            A, B, C = self.A, self.B, self.C
            Ds = self.Ds
            if i == 3:
                v0_c = C
                v1_c = B
                v2_c = A
                constant = self.Ds[0]
            elif i == 4:
                v0_c = A * C
                v1_c = A * B + C
                v2_c = A**2 + B
                constant = A * self.Ds[0] + self.Ds[1]
            elif i == 5:
                v0_c = A^2*C + B*C
                v1_c = A^2*B + B^2 + A*C
                v2_c = A^3 + 2*A*B + C
                constant = A^2*Ds[0] + B*Ds[0] + A*Ds[1] + Ds[2]
            elif i == 6:
                v0_c = A^3*C + 2*A*B*C + C^2
                v1_c = A^3*B + 2*A*B^2 + A^2*C + 2*B*C
                v2_c = A^4 + 3*A^2*B + B^2 + 2*A*C
                constant = A^3*Ds[0] + 2*A*B*Ds[0] + A^2*Ds[1] + C*Ds[0] + B*Ds[1] + A*Ds[2] + Ds[3]

            rhs = (v0_c * v0 + v1_c * v1 + v2_c * v2 + constant) % self.N
            assert rhs == new_x1


        self.__z.append(new_z)
        self.__k.append(new_k)

        self.__answer1.append(new_x1)
        self.__answer2.append(new_x2)

        self.__x1 = self.__x1[1:] + [new_x1]
        self.__x2 = self.__x2[1:] + [new_x2]

        # assumption: perfect guess z'
        return new_z

        # return (new_x1 - new_x2) % self.m1

    def check(self, zs, ks, Ds):
        res1 = self.__z == zs
        res2 = self.__k == ks
        res3 = self.Ds == Ds

        return (res1, res2, res3)


    def get_answer(self):
        return self.__answer1, self.__answer2


# cheating (gueesing z' is success)
prng = PRNG()

# public parameters
m1 = prng.m1
m2 = prng.m2
N = m1 * m2
a1 = prng.a1
a2 = prng.a2
u = power_mod(m1, -1, m2)

print("=============== Exploit ===============")

A = crt([a1[0], a2[0]], [m1, m2])
B = crt([a1[1], a2[1]], [m1, m2])
C = crt([a1[2], a2[2]], [m1, m2])
Ds = []
ks = []
zs = []

for i in range(10):
    z = prng.next(i)
    k = -z * u % m2

    if i >= 3:
        D = (-k + A * ks[-1] + B * ks[-2] + C * ks[-3]) * m1 % N
        Ds.append(D)

    zs.append(z)
    ks.append(k)

# check
res = all(prng.check(zs, ks, Ds))
assert res

# calculation x0, x1, x2
size = 8
mat = [
    [0 for _ in range(size)] for _ in range(size)
]

for i in range(3):
    mat[i][i] = 1

mat[3][3] = 2^32

for i in range(4):
    mat[4+i][4+i] = N

# v0_c, v1_c, v2_c, constant
polys = [
    [C, B, A, Ds[0]],
    [A * C, A * B + C, A**2 + B, A * Ds[0] + Ds[1]],
    [A^2*C + B*C, A^2*B + B^2 + A*C, A^3 + 2*A*B + C, A^2*Ds[0] + B*Ds[0] + A*Ds[1] + Ds[2]],
    [A^3*C + 2*A*B*C + C^2, A^3*B + 2*A*B^2 + A^2*C + 2*B*C, A^4 + 3*A^2*B + B^2 + 2*A*C, A^3*Ds[0] + 2*A*B*Ds[0] + A^2*Ds[1] + C*Ds[0] + B*Ds[1] + A*Ds[2] + Ds[3]]
]

def dump_mat(m):
    for row in m:
        print(row)

for i, poly in enumerate(polys):
    for j, v in enumerate(poly):
        mat[j][i+4] = v % N

M = matrix(ZZ, mat)

for b in M.LLL():
    answer = []
    _ys = []
    if abs(b[3]) == 2**32:
        for x in b[:3]:
            answer.append(abs(x))

        for x, z in zip(answer, zs):
            _ys.append(x - z)

        x = sum([_x * _a for _x, _a in zip(answer, reversed(a1))]) % m1
        y = sum([_y * _b for _y, _b in zip(_ys, reversed(a2))]) % m2

        if x - y == zs[3]:
            print(answer, _ys)
            break


true_x, true_y = prng.get_answer()
print(f"[{answer == true_x[:3]}] answer: {true_x[:3]}")
print(f"[{_ys == true_y[:3]}] answer: {true_y[:3]}")

Reference §