#!/usr/bin/env python3
import math
from Crypto.Util import number
from Crypto.PublicKey import RSA
from time import time


def extended_euclid(a, b):
    """Return ``(d, x, y)`` such that ``a * x + b * y == d``.

    Iterative extended Euclidean algorithm.  For ``b > 0`` the returned ``d``
    is ``gcd(a, b)`` (positive) and ``x``/``y`` are Bezout coefficients; ``a``
    may be negative.  When ``b <= 0`` no division step is performed and
    ``(a, 1, 0)`` is returned, so callers should pass a non-negative ``b``.
    """
    # (x_prev, y_prev) are the Bezout coefficients of the current ``a``,
    # (x_cur, y_cur) those of the current ``b``; initially a = 1*a + 0*b and
    # b = 0*a + 1*b.
    x_prev, x_cur = 1, 0
    y_prev, y_cur = 0, 1
    while b > 0:
        # Plain floor division: with b > 0 the remainder a - q*b lies in
        # [0, b) even for negative a (abs() here would break that).
        q = a // b
        a, b = b, a - q * b
        x_prev, x_cur = x_cur, x_prev - q * x_cur
        y_prev, y_cur = y_cur, y_prev - q * y_cur
    # When the loop never ran (b == 0 on entry) this yields (a, 1, 0)
    # instead of reading unassigned variables.
    return (a, x_prev, y_prev)

# Generate the RSA key components: public exponent e = 65537, two 1024-bit
# strong primes, the modulus n, the totient phi = (p - 1)(q - 1) and the
# private exponent d = e^-1 mod phi.
e = 2 ** 16 + 1
# e= is handed to getStrongPrime, presumably so it picks primes usable with
# this exponent -- see the PyCrypto Crypto.Util.number docs.
p, q = [number.getStrongPrime(1024, e=e) for _ in range(2)]
n, phi = p * q, (p - 1) * (q - 1)
d = number.inverse(e, phi)
# NOTE: this dumps the private key material to stdout; fine for a demo only.
print('p', p, '\n q', q, '\n n', n,
      '\n phi', phi, '\n e', e, '\n d', d)

# Sanity check: build a key object from the components and round-trip a
# random 64-bit integer through encrypt/decrypt.
key = RSA.construct((n, e, d, p, q))
random = number.getRandomInteger(64)
print('Encrypting random number %i' % random)
encrypted = key.encrypt(random, None)
decrypted = key.decrypt(encrypted)
print('Decrypted value %i' % decrypted)
print('Decryption worked!' if decrypted == random else 'uh-oh')

# CRT parameters for decrypt_chinese_rest: cp = q^-1 mod p, cq = p^-1 mod q,
# and the private exponent reduced modulo p - 1 and q - 1 respectively.
cp, cq = number.inverse(q, p), number.inverse(p, q)
dp, dq = d % (p - 1), d % (q - 1)


def decrypt_chinese_rest(x):
    """Decrypt ciphertext ``x`` using the Chinese Remainder Theorem.

    Exponentiates separately modulo ``p`` and ``q`` with the reduced private
    exponents ``dp``/``dq`` and recombines both residues modulo ``n``.
    Reads the module-level key values ``p``, ``q``, ``n``, ``cp``, ``cq``,
    ``dp`` and ``dq``; prints progress markers to stdout.
    """
    residue_p, residue_q = x % p, x % q
    print("decrypt_chinese_rest")
    # Step 2: one half-size exponentiation per prime factor.
    part_p = pow(residue_p, dp, p)
    part_q = pow(residue_q, dq, q)
    print('c done')
    # Step 3: recombine; cp = q^-1 mod p and cq = p^-1 mod q.
    return (part_p * cp * q + part_q * cq * p) % n


def decrypt_naive(x):
    """Decrypt ciphertext ``x`` as ``x ** d mod n`` with the full exponent.

    Baseline for the CRT variant in the benchmark.  Reads the module-level
    ``d`` and ``n``; prints progress markers to stdout.
    """
    print("decrypt_naive")
    plaintext = pow(x, d, n)
    print('n done')
    return plaintext


def benchmark_chunky(x, chunks=1):
    """Run :func:`benchmark` on ``x`` ``chunks`` times and print the totals.

    Sums the naive and CRT timings of all runs and prints a separator line
    followed by a tab-separated summary: total naive time, total CRT time,
    their difference and the fraction of time saved by CRT.

    :param x: plaintext integer passed to each benchmark run.
    :param chunks: number of runs to aggregate; the default of 1 matches the
        previous single-run behaviour.
    """
    results = [benchmark(i, x) for i in range(chunks)]
    elapsed_naive = sum((naive for naive, _ in results), 0.0)
    elapsed_chinese = sum((chinese for _, chinese in results), 0.0)
    # Guard the ratio: a zero naive total (no runs, or a coarse clock)
    # would otherwise raise ZeroDivisionError.
    if elapsed_naive:
        saved = 1.0 - (elapsed_chinese / elapsed_naive)
    else:
        saved = float('nan')
    # Tab-separated to match the column header printed by the driver.
    print('----------------------------------------------------\n',
          elapsed_naive, elapsed_chinese, elapsed_naive - elapsed_chinese, saved,
          sep='\t')


def benchmark(i, x):
    """Time CRT and naive decryption of one ciphertext and print a row.

    Encrypts ``x`` with the module-level ``key``, decrypts the ciphertext once
    with :func:`decrypt_chinese_rest` and once with :func:`decrypt_naive`,
    verifies both results and prints a tab-separated row: run number
    (``i + 1``), naive time, CRT time, their difference and the fraction of
    time saved by CRT.

    :param i: zero-based run index (printed one-based).
    :param x: plaintext integer.
    :returns: ``(elapsed_naive, elapsed_chinese)`` in seconds.
    :raises AssertionError: if either decryption does not reproduce ``x``.
    """
    # perf_counter is monotonic and high-resolution; time() is a wall clock
    # that may be coarse or jump, which distorts short interval measurements.
    from time import perf_counter

    # The ciphertext integer is taken from index 0 of encrypt's return value.
    ciphertext = key.encrypt(x, None)[0]

    start = perf_counter()
    result_chinese = decrypt_chinese_rest(ciphertext)
    elapsed_chinese = perf_counter() - start

    start = perf_counter()
    result_naive = decrypt_naive(ciphertext)
    elapsed_naive = perf_counter() - start

    if result_chinese != x or result_naive != x:
        raise AssertionError(
            'decryption unsuccessful, naive %i, chinese: %i, x %i'
            % (result_naive, result_chinese, x))

    # Guard the ratio in case the clock reports a zero-length naive run.
    if elapsed_naive:
        saved = 1.0 - (elapsed_chinese / elapsed_naive)
    else:
        saved = float('nan')
    # Tab-separated to line up with the column header printed by the driver.
    print(i + 1, elapsed_naive, elapsed_chinese, elapsed_naive - elapsed_chinese, saved,
          sep='\t')
    return elapsed_naive, elapsed_chinese

# Benchmark driver: print a tab-separated column header, then run 50 rounds,
# each on a fresh integer from number.getRandomInteger(2047).
print(*('#', 'naive duration', 'chinese duration', 'difference', '%'), sep='\t')
for i in range(50):
    benchmark_chunky(number.getRandomInteger(2047))