TL;DR

  • LFSRを利用した乱数生成器が与えられる
  • 閾値を与えて、その値より小さい数が生成された時に「その数を生成するのに要した回数」のみが与えられる
  • 閾値を大きくなるように設定すれば生成が失敗した時の上位bitが1だと判断出来るのでこれを利用してLFSRの出力とstep数の対応関係を複数集め、線形方程式を解いて初期状態を復元する

Prerequisite

  • LFSR
  • 線形代数

Writeup

次のようなスクリプトが動いている

#!/usr/local/bin/python

import secrets

class LFSR:
    def __init__(self, key, taps):
        self._s = key
        self._t = taps            

    def _sum(self, L):
        s = 0
        for x in L:
            s ^= x
        return s

    def _clock(self):
        b = self._s[0]
        self._s = self._s[1:] + [self._sum(self._s[p] for p in self._t)]
        return b

    def bit(self):
        return self._clock()

class RNG:
    def __init__(self, lfsr, N, nbits):
        self.lfsr = lfsr
        self.N = N
        self.nbits = nbits
        
        if not (pow(2, 27) < N < pow(2, 31)):
            raise ValueError("modulus is too big or small")
        
        K = pow(2, nbits) // N
        self.cutoff = K * N

    def get_random_nbit_integer(self):
        res = 0
        for i in range(self.nbits):
            res += self.lfsr.bit() << i
        return res
    
    def get_random_integer_modulo_N(self):
        count = 1
        while True:
            x = self.get_random_nbit_integer()
            if x < self.cutoff:
                return x % self.N, count
            count += 1

taps = [60, 58, 54, 52, 48, 47, 45, 43, 38, 36, 32, 28, 22, 21, 13, 9, 8, 5, 2, 0]
n = 64

with open("flag.txt", "r") as f:
    flag = f.read()

if __name__ == "__main__":
    print("Welcome to the unbiased random number factory!")
    N = int(input("What modulus would you like to use? Choose between 2^27 and 2^31: "))

    key = secrets.randbits(n)
    key_bits = [(key >> i)&1 for i in range(n)]
    
    lfsr = LFSR(key_bits, taps)
    rng = RNG(lfsr, N, 32)
    
    for _ in range(1024):
        c = input("Enter your command (R,F): ")
        if c.startswith("R"):
            x,t = rng.get_random_integer_modulo_N()
            print("creating this random number took {} attempts".format(t))
        elif c.startswith("F"):
            seed = int(input("what was my seed?"))
            if seed == key:
                print(flag)
            exit(0)
        else:
            print("unsupported command")

LFSRを使った乱数生成器が動いている。LFSRが何であるかは一番下にも載せたy011d4さんの資料を参考にしていただくとして、出力も状態の先頭のbitを出力しているだけであることから、十分な量のstep数とその出力の対応を得ることが出来れば初期状態を復元出来る。

具体的には、初期状態を$\mathbb F_2$のベクトル$\boldsymbol s$とおいた時に、$k$回出力した後の状態は次のような$\mathbb F_2$上の行列$M$を用いて$M^k \boldsymbol s$で表されることから、出力は$M^k$の一番上にある行ベクトル(以下、$M^k_0$とおく)と初期状態の内積となる。

$$ M = \begin{pmatrix} 0 & 1 & & \cr 0 & 0& 1 & \cr \vdots & & &\ddots & \cr 0& & && 1 \cr t_0 & t_1 & \cdots & & t_{n-1} \end{pmatrix} $$

ここで、$t_i \in \{0,1\}$はタップに対応する値である。

このLFSRの出力ビットを32個並べて32bitの数値としたものが乱数として扱われる。

また、この問題では最初に$N \ (2^{27} \lt N \lt 2^{31})$という値を提出する必要がある。この値は次のようなcutoff変数を定義する為に使われる(この変数を$c$とおく)

$$ \begin{aligned} K &= \left\lfloor\frac{2^{32}}{N}\right\rfloor \cr c &= K \cdot N = \left\lfloor\frac{2^{32}}{N}\right\rfloor \cdot N \end{aligned} $$

この$c$に対してLFSRで作った乱数$x$が$x < c$を満たしていれば、乱数生成器の出力となり、そうでなければそうなるまで繰り返す。問題では、この手順が何回繰り返されたかだけが開示される。

問題を解く上で重要な点として、乱数生成時に$c$以上の数$x$が得られた場合は最上位bitが特定出来るということである。LFSRの定義より$N$に対して次が成り立っている。

$$ \frac{2^{32}}N - \frac 12 \leq \left\lfloor \frac{2^{32}}N \right\rfloor \leq \frac{2^{32}}N $$

よって、両辺$N$を掛けて真ん中に$c$を持ってくると次が成り立つ。

$$ 2^{31} \leq 2^{32} - \frac N2 \leq \left\lfloor \frac{2^{32}}N \right\rfloor\cdot N = c \leq 2^{32} $$

一番左の不等式は$N$の範囲から導かれ、したがって$2^{31} \leq c \leq 2^{32}$が成り立つ。よってもし、$c \leq x$となる$x$が生成された場合は$2^{31} \leq x$であり、つまり$x$の最上位bitは1である。

この問題では、乱数生成に要した回数、つまり$c \leq x$になった回数+1を教えてくれる。1回の乱数生成では32個のLFSRの出力が用いられることから、失敗も含めて$n$回の乱数生成が行われた時、その時生成された数の最上位bitは$(n-1) \times 32 + 31$個目のLFSRの出力となる(下記ソルバのコードとの対応上、0-indexedとした)。

以上より、もし$c \leq x$にならなかった回数が1以上であれば、そこで作られた数の最上位bitに対応するLFSRの出力回数とその出力(1になる)が判明する。これはつまり、$\langle M^k_0, \boldsymbol s\rangle = 1$であるため、この関係を十分な量集めれば、連立線形方程式を解くことで$\boldsymbol s$を求めることが出来る(具体的には$M^k_0$を縦に並べた行列に右から$\boldsymbol s$を掛けた時に出力値を並べたベクトルになるような$\boldsymbol s$を求める)。

問題は$c$としてどのような値を設定するかということで、$c \leq x$になる確率がある程度大きくなくてはならない。これは$\left\lfloor \frac{2^{32}}N \right\rfloor$を小さくすれば良く、色々試していたら$N = 2^{30} + 2^{29}$で良い結果が得られた。この時$c=2^{31} + 2^{30}$となるようで、これなら(乱数の分布が一様ランダムだと仮定して)1/4の確率で$c \leq x$となる数$x$が得られる。

最後の問題として、出力が1になるような$M^k_0$だけを集めても次元が64にならない(63にはなる)ことがある。これでは初期状態を求めることが出来ないが、以降のある乱数の最上位ビットが0であると仮定して線形方程式を組み立てれば、「実際にそこが0である確率」と同じ確率で初期状態を求めることが出来る

一応、成功確率を挙げたり確実に解けるようにするために次のようなことを試した

  • もし1回の生成で乱数が得られたら、2/3の確率で最上位bitが0なのでこの場合を用いた
  • 線形方程式を解くのに用いる行列のランクが64になるような場合のみ選んだ

Code

from pwn import process
import sys


K = GF(2)

taps = [60, 58, 54, 52, 48, 47, 45, 43, 38, 36, 32, 28, 22, 21, 13, 9, 8, 5, 2, 0]
n = 64

DEBUG = len(sys.argv) > 1 and sys.argv[1] == "-d"

# exploit

# 1. preparation
N = 2**30 + 2**29
assert (pow(2, 27) < N < pow(2, 31))
k = pow(2, 32) // N
cutoff = k * N

m = [
    [0 for _ in range(64)] for _ in range(63)
]

for i in range(63):
    m[i][i+1] = 1

ts = [0 for _ in range(64)]
for t in taps:
    ts[t] = 1

m.append(ts)

m = matrix(K, m)
m_pows = []
_m = m^0

for k in range(1024 * 32):
    m_pows.append(_m[0])
    _m *= m


# 2. connect
if DEBUG:
    sc = process(["python3", "./debug_rejected.py"])
else:
    sc = process(["python3", "./rejected.py"])

sc.recvuntil(b"2^31: ")
sc.sendline(str(N).encode())

answer = None

if DEBUG:
    answer = eval(sc.recvline())

known = []
current_round = 0

lhs_matrix = None

while True:
    sc.recvuntil(b"command (R,F): ")
    sc.sendline(b"R")
    sc.recvuntil(b" number took ")
    t = int(sc.recvuntil(b" "))
    idx = current_round * 32 + 31

    current_round += t
    if t > 1:
        known.append((idx, 1))
        if lhs_matrix is None:
            lhs_matrix = matrix(K, m_pows[idx])
        else:
            lhs_matrix = lhs_matrix.stack(m_pows[idx])

        if lhs_matrix.rank() == 63:
            break

while True:
    sc.recvuntil(b"command (R,F): ")
    sc.sendline(b"R")
    sc.recvuntil(b" number took ")
    t = int(sc.recvuntil(b" "))
    idx = current_round * 32 + 31

    current_round += t
    if t == 1:
        if lhs_matrix.stack(m_pows[idx]).rank() == 64:
            v = [1 for _ in range(lhs_matrix.nrows())] + [0]
            v = vector(K, v)
            state = lhs_matrix.stack(m_pows[idx]).solve_right(v)
            print(state)
            if answer is not None:
                print(answer)
            break

recovered_key = 0
for i, b in enumerate(state):
    recovered_key += (int(b) << i)

sc.recvuntil(b"command (R,F): ")
sc.sendline(b"F")
sc.recvuntil(b"seed?")
sc.sendline(str(recovered_key).encode())
try:
    print(sc.recvline())
except:
    print("damedayo")

Flag

ローカルでやっただけだが、フラグがある問題リポジトリ上で解いたので表示された

dice{so-many-numbers-got-rejected-on-valentines-day-1cc16ff5b20d6be1fbd65de0d234608c}

Resources