TL;DR

  • フラグのビット列と提出したビット列の内積(但しmod 2)を得られる
  • 0と1を送った際の挙動の違いから、フラグのビットが無印の場合は確定で、Pt.2の場合は高確率で特定出来る

Writeup

次のようなスクリプトが動いている

import numpy as np
from os import urandom
import binascii
import hashlib
from secret import FLAG1, FLAG2

# Helpers

def digest_to_array(digest):
    hex_digest = binascii.hexlify(digest)
    binary = bin(int(hex_digest, base=16)).lstrip('0b')
    binary = np.array(list(binary))
    
    # currently leading 0s are gone, add them back
    missing_zeros = np.zeros(256 - len(binary))
    binary = np.concatenate((missing_zeros, binary.astype(int)))
    
    return binary.astype(int)

def bitstring_to_bytes(s):
    return int(s, 2).to_bytes((len(s) + 7) // 8, byteorder='big')

####################################################################################

def generate_hardcore(secret, r):
    return int(np.sum(secret * r) % 2)

def predictor(r, probability = 1):
    x_r = (r.copy() != digest_to_array(FLAG))
    np.random.seed(x_r)
    chance = np.random.rand()
    
    prediction = 0
    if chance <= probability:
        prediction = generate_hardcore(digest_to_array(FLAG), r)
    else:
        prediction = 1 - generate_hardcore(digest_to_array(FLAG), r)
        
    return int(prediction)

def parse_input():
    bitstring = input()
    assert len(bitstring) == 256
    assert set(bitstring) <= set(["0", "1"])
    
    bitstring = bitstring_to_bytes(bitstring)
    array = digest_to_array(bitstring) % 2
    return array

def Level(probability):
    hasher = hashlib.sha256()
    hasher.update(FLAG)
    encrypted_secret = hasher.hexdigest()
    problem_statement = \
        f"We're looking to find the secret behind this SHA1 hash <{str(encrypted_secret)}>. " \
         "Luckily for you, this socket takes a bitstring and predicts the dot product " \
        f"of the secret with that bit string (mod 2) with {int(100*probability)}% accuracy and sends you back the answer.\n"
        
    print(problem_statement)
    while True:
        array = parse_input()
        if array is None:
            return
        print(predictor(array, probability = probability))
    
def main():
    global FLAG
    diff = int(input("Select a difficulty (1/2):"))
    if diff == 1:
        FLAG = FLAG1
        Level(1)
    if diff == 2:
        FLAG = FLAG2
        Level(0.9)

if __name__ == "__main__":
    main()
        

problem_statementで書かれていることが全てで、与えたビット列とフラグのビット列の内積をとってmod 2したものをくれる。

また、無印とPt.2で同じコードが使われており、無印の場合はこの内積は確実に正しいものをくれるが、Pt.2は90%の確率で正しく、10%の確率で誤った結果を返す。

方針としては、1bitずつ当てていくことを考える。あるビットに注目して、それ以外のビットを固定して2通り試し、その際の挙動をオラクルとしてビットを当てる。

無印の場合は簡単で、フラグと提出結果に次のような対応がある(フラグのビットが0かつ提出時のビットが0の時の結果を$x \in {0,1}$とする)。

submit\flag01
0$x$$x$
1$x$$\lnot x$

注目しているビットにおいて、フラグのビットが0の時、提出したビットが何であれ、積は0となるから0を提出しても1を提出しても同じ結果が得られる。一方、フラグのビットが1の時は1を提出した際に内積が1増えることから、0を出した時と比べて結果は異なるはずである。よってこの2通りに差異があるかどうかでフラグのビットを特定出来る。

一方、Pt.2の場合は10%の確率で誤りが生じる。どちらも正しく判定する確率は$0.9 \times 0.9 = 0.81$であるから、8割の確率で無印の場合と同様に正しく判定出来るが、2割の確率で誤った結果が得られてしまう。

そこで試行回数を増やして対応する。今回は10回上記の判定法を試して最も多く判定されたビットをフラグとした。

また、predictor関数内の次に該当する箇所のせいで、毎度同じビット列を送るとchanceが固定されてしまい意味がないので、注目しているビット以外のビットは、10回どれもランダム生成したものを用いた。

    x_r = (r.copy() != digest_to_array(FLAG))
    np.random.seed(x_r)
    chance = np.random.rand()

なお、この判定方法でも256ビット全てを完璧に判定出来るわけでは無く、手元の実験では数ビットの誤りが発生する。そこで並列でコードを動かし、その結果を見比べて最終的に手作業でフラグを復元した。

ちなみに、フラグが通常の文字列であることから、各文字(バイト)のMSBに相当するビット(8の倍数番目のビット)は必ず0になる。下記コードでは試行回数/時間を抑えるためにそのバイトは必ず0であるという処理を加えている。

Code

※Pt.2のもの

import numpy as np
import binascii
import hashlib
import random
from pwn import remote
from Crypto.Util.number import long_to_bytes
from Crypto.Random.random import getrandbits

# Helpers

def digest_to_array(digest):
    hex_digest = binascii.hexlify(digest)
    binary = bin(int(hex_digest, base=16)).lstrip('0b')
    binary = np.array(list(binary))
    
    # currently leading 0s are gone, add them back
    missing_zeros = np.zeros(256 - len(binary))
    binary = np.concatenate((missing_zeros, binary.astype(int)))
    
    return binary.astype(int)

def bitstring_to_bytes(s):
    return int(s, 2).to_bytes((len(s) + 7) // 8, byteorder='big')


def get_rand_bitstrings(i):
    ret = ""
    for _ in range(i):
        ret = ret + str(getrandbits(1))

    return ret

on_local = False
if on_local:
    host = "localhost"
    port = 13337
    answer = b"xflagx{unkoburiburiunkomorimori}"
else:
    host = "ctf.b01lers.com"
    port = 9003
    answer = None

sc = remote(host, port)
sc.recvuntil(b"difficulty (1/2):")
sc.sendline("2".encode())
sc.recvuntil(b"hash <")
hex_hash = sc.recvuntil(b">")[:-1].decode()
sc.recvline()
sc.recvline()

flag = [0 for _ in range(256)]
for i in range(255):
    if i % 8 == 0:
        continue
    num_res = [0, 0]
    for _ in range(10):
        # create payload
        b1 = get_rand_bitstrings(i)
        b2 = get_rand_bitstrings(256 - i -1)
        # bit_vec[i] = 0
        bs = b1 + "0" + b2
        payload = digest_to_array(bitstring_to_bytes(bs))
        payload = "".join(map(str, payload))
        sc.sendline(payload.encode())
        res1 = int(sc.recvline())
        # bit_vec[i] = 1
        bs = b1 + "1" + b2
        payload = digest_to_array(bitstring_to_bytes(bs))
        payload = "".join(map(str, payload))
        sc.sendline(payload.encode())
        res2 = int(sc.recvline())
    
        if res1 == res2:
            num_res[0] += 1
        else:
            num_res[1] += 1

    print(i, num_res)
    if num_res[0] > num_res[1]:
        flag[i] = 0
    else:
        flag[i] = 1

print(flag)
flag_s = int("".join(map(str, flag)), 2)
print(long_to_bytes(flag_s))

if answer is not None:
    answer_bits = list(digest_to_array(answer))
    correct_cnt = 0
    for i in range(256):
        if answer_bits[i] == flag[i]:
            correct_cnt += 1

    print(correct_cnt)

Flag

bctf{goldreich-levin-theorem.:D}