Wednesday, August 31, 2011

Dissecting RSA keys in Python (4)

Still working with RSA keys. I've now done the most important of the items left undone so far: using the keys to encrypt (and decrypt) a message with my own code.

In the process, I learned a simple new Python fact: the built-in pow function can take a modulus as an optional third argument.

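And it works where this doesn't seem to: (x**y) % z. The three-argument form reduces modulo z as it goes, while (x**y) % z first has to build the full power x**y, which can be enormous.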

(I discovered this by digging into the rsa module source.)
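The reason the three-argument pow copes is that it never needs x**y itself: it can reduce modulo z after every multiplication. Here is a minimal sketch of right-to-left square-and-multiply, the standard textbook method for modular exponentiation. It's an illustration checked against the built-in pow, not a copy of CPython's own implementation, and the name mod_pow is just for this example.

```python
def mod_pow(base, exponent, modulus):
    """Right-to-left square-and-multiply: base**exponent % modulus.

    Every intermediate value stays below modulus**2, so the full power
    base**exponent is never built.
    """
    if modulus == 1:
        return 0
    result = 1
    base %= modulus
    while exponent > 0:
        if exponent & 1:                    # low bit set: multiply this power in
            result = (result * base) % modulus
        base = (base * base) % modulus      # square for the next bit
        exponent >>= 1
    return result


# Small cases where (x**y) % z is still cheap, so all three can be compared.
for x, y, z in [(4, 13, 497), (12345, 1021, 3233), (7, 0, 11)]:
    assert mod_pow(x, y, z) == pow(x, y, z) == (x**y) % z

print(mod_pow(4, 13, 497))  # 445
```

The loop runs once per bit of the exponent, so even a 2048-bit exponent takes only about 2048 squarings, each on numbers no bigger than the modulus squared.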

python -m timeit -s "import rsa;  \
f = open('id_rsa');  data = f.read(); f.close(); k = rsa.PrivateKey.load_pkcs1(data)"\
 "41330915578951772302369**k.e % k.n"

10000 loops, best of 3: 79.4 usec per loop

python -m timeit -s "import rsa;  \
f = open('id_rsa');  data = f.read(); f.close(); k = rsa.PrivateKey.load_pkcs1(data)"\
 "pow(41330915578951772302369, k.e, k.n)"

10000 loops, best of 3: 70.2 usec per loop
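Whether (x**y) % z is usable comes down to how big x**y gets before the % is applied. A minimal sketch that bounds the bit length of x**y without computing it: the exponent 65537 is an assumption (it's the usual RSA public exponent, but k.e isn't printed here), and the 2048-bit d is a hypothetical stand-in, not a value from the key.

```python
def power_bits_upper_bound(base, exponent):
    """Upper bound on (base**exponent).bit_length(), without computing the power.

    If base < 2**b then base**exponent < 2**(b * exponent).
    """
    return base.bit_length() * exponent


base = 41330915578951772302369        # the base used in the timeit runs above
e = 65537                             # assumed: the common RSA public exponent
bits_e = power_bits_upper_bound(base, e)
print(bits_e)                         # 4980812 bits: big, but fits in memory

# A private exponent d is typically about as many bits as n. For a
# hypothetical 2048-bit d, even the *number of bits* is a ~2054-bit number.
d_hypothetical = (1 << 2047) + 12345  # hypothetical stand-in, not a real key
bits_d = power_bits_upper_bound(base, d_hypothetical)
print(bits_d.bit_length())            # 2054
```

So base**e is a few hundred kilobytes and the % afterwards is fine, while base**d could never be stored, which is why only the three-argument pow can decrypt.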

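In the example here, both work and have about the same timing. But in the code below, the first version hangs when doing decryption: there both the base c and the exponent d are about as large as n, and it's the huge exponent that makes c**d impossible to compute. Here's the output, followed by the script.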

> python 
m:   Hello, secret world!
p:   .xyz.Hello, secret world!.xyz.
a:   413309155789517723023698766343791993289928631329
c:   11434905702482726455415220687715293368262190253795 ..
i:   413309155789517723023698766343791993289928631329
r:   Hello, secret world!

import rsa

with open('id_rsa') as f:
    data = f.read()
k = rsa.PrivateKey.load_pkcs1(data)
n = k.n
e = k.e
d = k.d

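def my_atoi(s):
    # Treat the string as a big-endian base-256 number: the first
    # character becomes the most significant byte (matching 'a' above)
    L = [ord(c) for c in s]
    place = 256
    iL = L[::-1]              # reversed, so iL[0] is the last character
    x = iL[0]
    for i in iL[1:]:
        x += i*place
        place *= 256
    return x

def my_itoa(i):
    # Inverse of my_atoi: peel off the least significant byte each pass
    rL = list()
    while i:
        rL.insert(0, i % 256)     # prepend, so the first character ends up first
        i = i // 256              # integer division (also correct on Python 3)
    return ''.join([chr(n) for n in rL])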

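def encrypt(m):
    # OK because e is small (typically 65537); pow(m, e, n) works too
    return m**e % n

def decrypt(c):
    # note: c**d % n hangs, since d is about as large as n and c**d is
    # astronomically large; three-argument pow reduces mod n as it goes
    return pow(c, d, n)

if __name__ == '__main__':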
    m = 'Hello, secret world!'
    pad = '.xyz.'
    p = pad + m + pad
    a = my_atoi(m)
    c = encrypt(a)
    i = decrypt(c)
    r = my_itoa(i)
    r = r.replace(pad,'')

    L = zip('mpacir',[m,p,a,c,i,r])
    N = 50
    for varname, var in L:
        s = str(var)
        line = varname + ':   ' + s[:N]
        if len(s) > N:
            line += ' ..'         # mark values that were truncated
        print(line)               # single-argument print: Python 2 and 3
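For anyone without an id_rsa file handy, here is a self-contained sketch of the same round trip. It swaps in Python 3's int.from_bytes and int.to_bytes for my_atoi and my_itoa (big-endian, matching the a value in the output above), and uses pow(e, -1, phi), which needs Python 3.8+, to get d. The primes are the Mersenne primes 2**89 - 1 and 2**127 - 1, chosen only because they're well known; that also makes the key completely insecure, so treat this purely as an illustration. The helper names text_to_int and int_to_text are hypothetical.

```python
# Toy key from two well-known (and therefore insecure) Mersenne primes.
p = 2**89 - 1
q = 2**127 - 1
n = p * q                   # about 216 bits, enough for a 20-byte message
phi = (p - 1) * (q - 1)
e = 65537
d = pow(e, -1, phi)         # modular inverse of e mod phi (Python 3.8+)


def text_to_int(s):
    # Big-endian: first character is the most significant byte
    return int.from_bytes(s.encode('latin-1'), 'big')


def int_to_text(i):
    length = (i.bit_length() + 7) // 8
    return i.to_bytes(length, 'big').decode('latin-1')


m = 'Hello, secret world!'
a = text_to_int(m)
assert a < n                # textbook RSA only handles messages below n
c = pow(a, e, n)            # encrypt
i = pow(c, d, n)            # decrypt
print(int_to_text(i))       # Hello, secret world!
```

Like the script above, this is unpadded "textbook" RSA, which is fine for poking at the arithmetic but not for protecting real data.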