A simple Hastad attack with e=65537... What can go wrong?
28 October, 2021

This challenge is about a simple server that sends out RSA-encrypted copies of a flag. We are going to use the classic Hastad broadcast attack, but with a small twist: optimising the code :)

We have a hidden flag value M, which is encrypted to C by:

$$C = M^{65537} \mod N$$

where N = pq, p and q are two big (512-bit) primes, and N > M. This makes N 1024 bits long. The server generates C and N and sends them to us, the client, as many times as we like, each time with a different C and a different N > M built from entirely different p and q.


IDEA

Notice that M never changes, nor is it randomly padded. We can use Hastad's attack: collect 65537 different Cs (C_1, C_2, ...) and 65537 different Ns (N_1, N_2, ...), then recover the actual value of M^65537 by:

$$M^{65537} \mod lcm(N_1, N_2, ...) = crt([C_1, C_2, C_3, ..., C_{65537}], [N_1, N_2, N_3, ..., N_{65537}])$$

If every N is built from different primes, then lcm(N_1, N_2, ...) = N_1N_2...N_65537, so

$$M^{65537} \mod lcm(N_1, N_2, ...) = M^{65537} \mod N_1N_2...N_{65537}$$

Since

$$M < N_i \text{ for every } i \in [1, 65537] \\ \rightarrow M^{65537} < N_1N_2...N_{65537} \\ \rightarrow M^{65537} \mod N_1N_2...N_{65537} = M^{65537}$$

To recover M, all we need is to take the 65537-th root of M^65537, which can be done very fast using Sagemath, and we're done!
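To make the idea concrete, here is a toy version in plain Python. It is only a sketch under loud assumptions: it uses e = 3 and tiny hand-picked primes instead of the challenge's e = 65537 and 1024-bit moduli, and the helper names `crt_all` and `integer_root` are made up for this illustration (the real solve uses Sage's crt() and nth_root()).

```python
from math import prod

def crt_all(remainders, moduli):
    """Combine x = r_i (mod n_i), for pairwise-coprime n_i, into x mod prod(n_i)."""
    total = prod(moduli)
    x = 0
    for r, n in zip(remainders, moduli):
        partial = total // n
        # pow(partial, -1, n) is the modular inverse (needs Python 3.8+)
        x += r * partial * pow(partial, -1, n)
    return x % total

def integer_root(value, e):
    """Largest integer r with r**e <= value, found by binary search."""
    lo, hi = 0, 1
    while hi ** e <= value:
        hi *= 2
    # Invariant: lo**e <= value < hi**e
    while hi - lo > 1:
        mid = (lo + hi) // 2
        if mid ** e <= value:
            lo = mid
        else:
            hi = mid
    return lo

# Toy parameters (assumptions for the demo, not the challenge's values)
e = 3
M = 424242                                   # the "flag", smaller than every modulus
moduli = [1009 * 1013, 1019 * 1021, 1031 * 1033]
ciphertexts = [pow(M, e, N) for N in moduli]

# e ciphertexts of the same unpadded M give M**e exactly, with no modular reduction
m_pow_e = crt_all(ciphertexts, moduli)
recovered = integer_root(m_pow_e, e)
print(recovered == M)
```

The same reasoning carries over to e = 65537; only the number of C-N pairs and the size of the numbers change, which is exactly where the trouble starts.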


Collect C-N Pairs

Collecting 65537 Cs and Ns is doable. Since the server sends us 100 C-N pairs per request, we just need to send (65537//100 + 1) = 656 requests in total. It took a while (about an hour), but it is definitely feasible.

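    from pwn import remote
    from tqdm import trange

    ##### socket #####
    host = '34.136.227.174'
    port = 25455

    def recving():
        """Open one connection and read the 100 C-N pairs the server sends."""
        ns = []
        cs = []
        io = remote(host, port)
        for _ in range(100):
            # Fields are separated by b':::'; N comes first and C comes last
            data = io.recvline().split(b':::')
            ns.append(int(data[0]))
            cs.append(int(data[-1]))
        io.close()
        return cs, ns

    ##### main #####
    # Get results
    cs, ns = [], []
    for _ in trange(65537 // 100 + 1):
        result = recving()
        cs += result[0]
        ns += result[1]

    # The solver scripts do `from c import c` and `from n import n`,
    # so the saved variables have to be named c and n.
    with open('c.py', 'w') as f:
        f.write(f'c = {cs}')
    with open('n.py', 'w') as f:
        f.write(f'n = {ns}')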

After that, we have our values C and N in two files: c.py and n.py.

From there, we just need to solve it with:

    from sage.all import *
    from Crypto.Util.number import *
    from c import c
    from n import n

    m_65537 = crt(c, n)
    print(long_to_bytes(Integer(m_65537).nth_root(65537)))

Reality is Harsh

So we waited...

One hour... Two hours... Three hours... (actually I gave up after 10 mins, this is just for dramatic purposes, but my friend Timmy ran it for another 2 hours :V he can confirm)

And it did not give us any output.

Why did it take so long???


Sage's CRT is Slow

Here's a snippet from Sage's implementation of crt(): the function CRT_list(), which handles the case where the inputs to crt() are lists.

    ### Snippet from CRT_list()
    x = v[0]
    m = moduli[0]
    from sage.arith.functions import lcm
    for i in range(1, len(v)):
        x = CRT(x,v[i],m,moduli[i])
        m = lcm(m,moduli[i])
    return x % m

Basically, this function builds up the calculation like so:

$$\text{C: list of rems, N: list of mods}\\ \\ x_1 \leftarrow C_1 \text{ ; } m_1 \leftarrow N_1 \\ \\ \text{loop}\\ x_i \leftarrow CRT([x_{i-1}, C_i], [m_{i-1}, N_i]) \\ m_i \leftarrow lcm(m_{i-1}, N_i) \\ \\ \text{for } i \in [2, len(C)]$$

The result of the algorithm is x_{len(C)}.

The algorithm works fine and gets the job done. The problem arises when the number of values passed to CRT is BIG.
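The left-to-right fold can be mimicked in plain Python to watch the accumulated modulus grow. This is a rough sketch, not Sage's code: `crt_pair` and `crt_sequential` are hypothetical names, the moduli are assumed to be pairwise coprime (so lcm is just the product), and the inputs are tiny primes picked for the demo.

```python
def crt_pair(a, m, b, n):
    """Merge x = a (mod m) and x = b (mod n), with gcd(m, n) == 1, into (x, m*n)."""
    # Find t such that a + m*t = b (mod n); needs Python 3.8+ for pow(m, -1, n)
    t = ((b - a) * pow(m, -1, n)) % n
    return a + m * t, m * n

def crt_sequential(remainders, moduli):
    """Fold the pairs in from left to right, recording the modulus size after each step."""
    x, m = remainders[0], moduli[0]
    sizes = [m.bit_length()]
    for r, n in zip(remainders[1:], moduli[1:]):
        x, m = crt_pair(x, m, r, n)
        sizes.append(m.bit_length())
    return x, m, sizes

# Demo values (assumptions): four small primes and a secret below their product
moduli = [101, 103, 107, 109]
secret = 123456
remainders = [secret % n for n in moduli]

x, m, sizes = crt_sequential(remainders, moduli)
print(x == secret, sizes)
```

Every step touches the whole accumulated value, and `sizes` keeps climbing: each merge is a bit more expensive than the one before it.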


To demonstrate what I mean, let's dig in to see what happens in each loop of that CRT_list() function.

x_1 has 1024 bits. Let's call this number of bits b and assume that the time taken to write data to a b-bit number is T_b, and that it scales linearly with b.


After each loop, we need to create a new variable x_i with b more bits than x_{i-1} and write data to it. So over time, the number of bits of x_i grows like this:

$$[b, 2b, 3b, 4b, ..., 65537b]$$

The total time taken to generate variables and write data will be like this:

$$T_{total} = \sum_{i=1}^{65537}T_{ib} = \sum_{i=1}^{65537}iT_b = T_b\frac{(65537)*(65537 + 1)}{2} = 2147581953T_b$$

As we can see, for b = 1024, the time taken to write one number is negligible, but 2 billion similar writes make a HUGE difference.


Let's analyse the runtime in O-notation. Replacing 65537, the amount of numbers used in CRT, with k, we have:

$$T_{total} \approx \frac{T_bk^2}{2} = O(k^2)$$

Not a very good runtime to be honest ✌️


And I haven't mentioned the runtime of multiplying 2 numbers, which is the more dominant cost. It takes O(b^(log2 3)) with respect to the number of bits, and is used a lot during CRT. Here I will stick to the write-data cost as it is easier to elaborate 😗


From O(k^2) to O(k log k)

Instead of doing CRT sequentially from left to right like Sage's algorithm, in each loop we do CRT for EACH PAIR OF NUMBERS in the array. After k/2 runs we have 2 new lists of k/2 2b-bit remainders and moduli. Then we do the same thing again, yielding 2 new lists of k/4 4b-bit remainders and moduli... until the list of remainders has only one element 😗
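The pairwise scheme can be sketched in plain Python as follows. This is only an illustration under assumptions: the moduli are pairwise coprime, the lists are non-empty, and `crt_pair` / `crt_tree` are made-up helper names rather than anything from Sage.

```python
def crt_pair(a, m, b, n):
    """Merge x = a (mod m) and x = b (mod n), with gcd(m, n) == 1, into (x, m*n)."""
    # Find t such that a + m*t = b (mod n); needs Python 3.8+ for pow(m, -1, n)
    t = ((b - a) * pow(m, -1, n)) % n
    return a + m * t, m * n

def crt_tree(remainders, moduli):
    """Merge neighbours two at a time, level by level, until one pair is left."""
    r, m = list(remainders), list(moduli)
    while len(r) > 1:
        new_r, new_m = [], []
        for i in range(0, len(r) - 1, 2):
            x, mod = crt_pair(r[i], m[i], r[i + 1], m[i + 1])
            new_r.append(x)
            new_m.append(mod)
        if len(r) % 2 == 1:
            # An odd element out is carried to the next level unchanged
            new_r.append(r[-1])
            new_m.append(m[-1])
        r, m = new_r, new_m
    return r[0], m[0]

# Demo values (assumptions): five small primes and a secret below their product
moduli = [101, 103, 107, 109, 113]
secret = 987654321
remainders = [secret % n for n in moduli]

value, modulus = crt_tree(remainders, moduli)
print(value == secret)
```

The point of the shape is that the numbers being merged at any level have similar sizes, so the huge operands only show up in the last few merges instead of in every single one.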


In the first k/2 runs, the time to write data to k/2 2b-bit numbers is:

$$T_{2b} * \frac{k}{2} = 2T_{b} * \frac{k}{2} = kT_b$$

The next k/4 runs take:

$$T_{4b} * \frac{k}{4} = 4T_b * \frac{k}{4} = kT_b$$

And so on...


Since the number of runs halves each loop, we have log2(k) loops in total, which yields a total runtime for creating and writing memory of:

$$T_{total} = kT_b\log(k) = O(k\log(k))$$

Compared to the previous O(k^2) algorithm, we have achieved a speed-up of about (for k=65537, taking log in base 2):

$$\frac{k}{2} / \log(k) = \frac{65537}{2 * \log(65537)} \approx 2048 \text{ times}$$

Again, this is only the speed-up ratio for creating and writing new memory. For multiplication and the other more time-consuming operations in crt() the ratio is different, but I guess we can expect a value around this.
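The two write-cost models can also be counted directly. The sketch below is a back-of-the-envelope check, not a benchmark: costs are in units of T_b (one write of a b-bit number), the function names are invented for the illustration, and k = 2^16 is used instead of 65537 so that every level of the tree splits evenly.

```python
def sequential_cost(k):
    """Left-to-right fold: the i-th value written is i blocks of b bits long."""
    cost, size = 0, 0
    for _ in range(k):
        size += 1
        cost += size
    return cost

def tree_cost(k):
    """Pairwise merging: at each level, count the sizes of the merged outputs."""
    sizes = [1] * k
    cost = 0
    while len(sizes) > 1:
        merged = [sizes[i] + sizes[i + 1] for i in range(0, len(sizes) - 1, 2)]
        cost += sum(merged)
        if len(sizes) % 2 == 1:
            merged.append(sizes[-1])  # odd element out is carried over, not rewritten
        sizes = merged
    return cost

k = 2 ** 16
seq = sequential_cost(k)
tree = tree_cost(k)
print(seq, tree, seq / tree)
```

For k = 2^16 the tree has 16 levels of k block-writes each, and the ratio between the two counts lands right next to the 2048 estimated above.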


With that in mind, let's implement the code (btw for some reason the algorithm runs faster when each loop does CRT on every 8 NUMBERS instead of 2):

    from sage.all import *
    from Crypto.Util.number import *
    from c import c
    from n import n

    c = c[:65537]
    n = n[:65537]

    def CRT_(r, m, SEG_SIZE=8, debug=False):
        """
        Purpose:
            Function that calculates CRT on chunks of <SEG_SIZE> numbers
            in the array rather than the whole array at once.
            Works pretty nice with big array and (probably) big number.
            Built on the base of Sage's crt() function.

        Args:
            r: List of remainders.
            m: List of modulus.
            SEG_SIZE=8: Number of values that CRT should work on once at a time.
            debug=False: Print out some debug data.
        """
        if debug:
            print(f'[ i ] Calculate CRT with chunk size {SEG_SIZE}...')
            print(f'[ i ] Start loop with len = {len(r)}')

        while len(r) != 1:
            newR = []
            newM = []
            for i in range(0, len(r), SEG_SIZE):
                if len(r) - i == 1:
                    # A single leftover element is carried over unchanged
                    newR.append(r[i])
                    newM.append(m[i])
                else:
                    crt_ = crt(r[i:i+SEG_SIZE], m[i:i+SEG_SIZE])
                    prod = 1
                    for _m in m[i:i+SEG_SIZE]:
                        prod *= _m
                    newR.append(crt_)
                    newM.append(prod)
            r = newR
            m = newM
            if debug:
                print(f'[ i ] Update loop with len = {len(r)}')

        if debug:
            print(f'[ i ] Finished :D')
        return r[0]

    # Got CRT
    # m_65537 = crt(c, n) # <- too slow !
    m_65537 = CRT_(c, n, debug=True)
    print(long_to_bytes(Integer(m_65537).nth_root(65537)))

And the algorithm only took us 3 minutes to run, which is definitely faster than infinity :)

Even Better: From 3 to 1 minute (update)

It's great that my writeup got so much praise from the author, cothan, who is also a long-time Crypto player 😃 I'm crying in happiness right now TwT. Although I came up with the algorithm, I got stuck and gave up half-way through. It was Timmy who stuck to my idea and implemented it to solve the challenge (thank you very^n much, where n goes to infinity) and get juicy points 😗 He also noticed that, somehow, grouping 8 crt()s together seems to be a better choice than 2(?)

cothan also mentioned that applying parallelism to the algorithm is a very good way to improve it.

Now, I'm not very familiar with Julia, but parallelism is doable with ease. The key is noticing that within each loop, the results of the crt() calls do not affect each other until the next loop. This means we can have multiple calls to crt() running simultaneously in a loop, which improves our performance a lot. It would have been better coded in Julia instead of Python, but for now let's stick to Python. With a few tweaks to the code, here's what we've got:

    from sage.all import *
    from Crypto.Util.number import *
    import concurrent.futures
    from c import c
    from n import n

    c = c[:65537]
    n = n[:65537]

    def CRT_(r, m, SEG_SIZE=12, debug=False, PROCESS_NO=8):
        """
        Purpose:
            Function that calculates CRT on chunks of <SEG_SIZE> numbers
            in the array rather than the whole array at once.
            Works pretty nice with big array and (probably) big number.
            Test: 65537 1024-bit numbers on a single core -> 3 mins 15 seconds.
            Test: 65537 1024-bit numbers on 8 cores -> 1 mins 2 seconds.
            Built on the base of Sage's crt() function.

        Args:
            r: List of remainders.
            m: List of modulus.
            SEG_SIZE=12: Number of values that CRT should work on once at a time.
            debug=False: Print out some debug data.
            PROCESS_NO=8: Number of processes to run at the same time.
        """
        assert len(r) == len(m)
        if debug:
            print(f'[ i ] Calculate CRT with chunk size {SEG_SIZE}...')
            print(f'[ i ] Start loop with len = {len(r)}')

        with concurrent.futures.ProcessPoolExecutor(PROCESS_NO) as executor:
            while len(r) != 1:
                newR = []
                newM = []
                for i in range(0, len(r), SEG_SIZE):
                    if len(r) - i == 1:
                        newR.append(Integer(r[i]))
                        newM.append(Integer(m[i]))
                    else:
                        # Chunks are independent, so each one goes to the pool
                        newR.append(executor.submit(crt, r[i:i+SEG_SIZE], m[i:i+SEG_SIZE]))
                        newM.append(executor.submit(prod, m[i:i+SEG_SIZE]))

                # Obtain processes' results :3
                for i in range(len(newR)):
                    if not isinstance(newR[i], Integer):
                        newR[i] = newR[i].result()
                        newM[i] = newM[i].result()

                r = newR
                m = newM
                if debug:
                    print(f'[ i ] Update loop with len = {len(r)}')

        if debug:
            print(f'[ i ] Finished :D')
        return r[0]

    # Got CRT
    # m_65537 = crt(c, n) # <- too slow !
    m_65537 = CRT_(c, n, debug=True, SEG_SIZE=8)
    print(long_to_bytes(Integer(m_65537).nth_root(65537)))

And this is the result on my 8-thread, 4 core machine 👯‍♀️👯‍♀️👯‍♀️👯‍♀️ wow!! 3 times faster.

... and better?

... And we can do even better than this! You see, Sage's crt() only returns the new remainder; it doesn't return the new modulus, although it calculates the new modulus along the way as well! This forces us to recalculate it later just to append it to our new moduli array, which seems redundant to be honest.

    ### Snippet from CRT_list()
    x = v[0]
    m = moduli[0]
    from sage.arith.functions import lcm
    for i in range(1, len(v)):
        x = CRT(x,v[i],m,moduli[i])
        m = lcm(m,moduli[i])
    return x % m # Should also return m

Sage's code base is already nice, so I just pulled the functions crt() and CRT_list() out of the source and modified them a little so that they return the new modulus as well:

    from sage.all import *
    from sage.structure.coerce import py_scalar_to_element
    from sage.arith.functions import lcm
    from Crypto.Util.number import *
    import concurrent.futures

    def crt_2(a, b, m=None, n=None):
        """
        Purpose:
            Modified from Sage's crt() function.
            Returns CRT value with moduli instead of just CRT.

        Args:
            a, b: Remainder 1 and 2
            m, n: Modulus 1 and 2
        """
        try:
            f = (b-a).quo_rem
        except (TypeError, AttributeError):
            # Maybe there is no coercion between a and b.
            # Maybe (b-a) does not have a quo_rem attribute
            a = py_scalar_to_element(a)
            b = py_scalar_to_element(b)
            f = (b-a).quo_rem

        g, alpha, beta = xgcd(m, n)
        q, r = f(g)
        if r != 0:
            raise ValueError("No solution to crt problem since gcd(%s,%s) does not divide %s-%s" % (m, n, a, b))

        x = a + q*alpha*py_scalar_to_element(m)
        l = lcm(m, n)
        return x % l, l

    def crt_(r, m):
        """
        Purpose:
            Modified from Sage's CRT_list() function.
            Returns CRT value with moduli instead of just CRT.

        Args:
            r: List of remainders.
            m: List of modulus.
        """
        res = r[0]
        prod = m[0]
        for i in range(1, len(r)):
            res, prod = crt_2(res, r[i], prod, m[i])
        return res % prod, prod

    def CRT_(r: list, m: list, SEG_SIZE=12, NO_CORES=8, debug=False):
        """
        Purpose:
            Function that calculates CRT on chunks of <SEG_SIZE> numbers
            in the array rather than the whole array at once. The goal is
            to reduce long multiplication time.
            Works pretty nice with big array and (probably) big number.
            Test: 65537 1024-bit numbers on a single core -> 3 mins 15 seconds.
            Test: 65537 1024-bit numbers on 8 cores -> 1 mins 2 seconds.
            Built on the base of Sage's crt() function.

        Args:
            r: List of remainders.
            m: List of modulus.
            SEG_SIZE=12: Number of values that CRT should work on once at a time.
            NO_CORES=8: Number of cores in your machine.
            debug=False: Print out some debug data.
        """
        assert len(r) == len(m) >= 2
        assert SEG_SIZE > 1
        if debug:
            print(f'[ i ] Calculate CRT with chunk size {SEG_SIZE}...')
            print(f'[ i ] Start loop with len = {len(r)}')

        with concurrent.futures.ProcessPoolExecutor(NO_CORES) as executor:
            while len(r) != 1:
                newR = []
                newM = []
                futures = []
                for i in range(0, len(r), SEG_SIZE):
                    if len(r) - i == 1:
                        newR.append(r[i])
                        newM.append(m[i])
                    else:
                        futures.append(executor.submit(crt_, r[i:i+SEG_SIZE], m[i:i+SEG_SIZE]))

                # Obtain processes' results :3
                for future in futures:
                    result = future.result()
                    newR.append(result[0])
                    newM.append(result[1])

                r = newR
                m = newM
                if debug:
                    print(f'[ i ] Update loop with len = {len(r)}')

        if debug:
            print(f'[ i ] Finished :D')
        return r[0]

    # Got CRT
    from c import c
    from n import n

    c = c[:65537]
    n = n[:65537]

    # m_65537 = crt(c, n) # <- too slow !
    m_65537 = CRT_(c, n, debug=True, SEG_SIZE=8, NO_CORES=8)
    print(long_to_bytes(Integer(m_65537).nth_root(65537)))

With this final piece of code, we have shaved about 20 seconds off the runtime. Which is actually pretty crazy, considering that the intended solution for this challenge was running the code for about 3 hours using optimised CRT() code written in Julia instead of Python (confirmed by cothan himself)!! Now, a low-end user can run this algorithm within minutes, seconds even! On a machine with many more cores, it can perform this task like a breeze!

::: Made with (❤️ && 😭) by Mistsu :::