import math
from time import perf_counter
from time import time

from Crypto.PublicKey import RSA
from Crypto.Util import number

def extended_euclid(a, b):
    """Run the extended Euclidean algorithm on ``a`` and ``b``.

    Returns a tuple ``(d, x, y)`` satisfying ``a * x + b * y == d``.
    ``abs(d)`` is gcd(a, b); for non-negative inputs ``d`` is the gcd itself.

    >>> extended_euclid(240, 46)
    (2, -9, 47)
    >>> extended_euclid(5, 0)
    (5, 1, 0)
    """
    # (x2, y2) are the Bezout coefficients of the current ``a``,
    # (x1, y1) those of the current ``b``. Starting values encode
    # a = 1*a + 0*b and b = 0*a + 1*b, so b == 0 correctly yields (a, 1, 0)
    # (previously that path read x2/y2 before assignment and crashed).
    x2, x1 = 1, 0
    y2, y1 = 0, 1
    while b != 0:
        # Plain floor division: abs(a // b) gave a wrong quotient whenever
        # a // b was negative, breaking the Bezout invariant.
        q = a // b
        a, b = b, a - q * b
        x2, x1 = x1, x2 - q * x1
        y2, y1 = y1, y2 - q * y1
    return (a, x2, y2)

# Generate the key components. e = 2**16 + 1 = 65537 is the public exponent.
e = 2 ** 16 + 1
p = number.getStrongPrime(1024, e=e)
q = number.getStrongPrime(1024, e=e)
n = p * q
phi = (p - 1) * (q - 1)
d = number.inverse(e, phi)
# NOTE: this prints the private values (p, q, phi, d) to stdout.
print('p', p, '\n q', q, '\n n', n, '\n phi', phi, '\n e', e, '\n d', d)

# Sanity check: round-trip a random 64-bit integer through the library key.
# NOTE(review): key.encrypt(m, None) / key.decrypt(tuple) is the legacy
# PyCrypto API; PyCryptodome's RsaKey does not support it — confirm which
# library is installed.
key = RSA.construct((n, e, d, p, q))
random = number.getRandomInteger(64)
print('Encrypting random number %i' % random)
encrypted = key.encrypt(random, None)
decrypted = key.decrypt(encrypted)
print('Decrypted value %i' % decrypted)
if decrypted != random:
    print('uh-oh')
else:
    print('Decryption worked!')

# Precomputed CRT parameters used by decrypt_chinese_rest.
# cp = q^-1 mod p and cq = p^-1 mod q are *modular* inverses; the previous
# pow(q, -1) computed the float 1/q, which truncated to 0 (or overflowed).
cp: int = number.inverse(q, p)
cq: int = number.inverse(p, q)
# Reduced private exponents for the half-size exponentiations.
dp: int = d % (p - 1)
dq: int = d % (q - 1)


def decrypt_chinese_rest(x):
    """Decrypt the integer ciphertext ``x`` using the Chinese Remainder Theorem.

    Performs two half-size modular exponentiations (mod ``p`` with exponent
    ``dp`` and mod ``q`` with exponent ``dq``) and recombines them with the
    module-level coefficients ``cp`` and ``cq``, which are expected to be
    q^-1 mod p and p^-1 mod q respectively.

    :param x: ciphertext as an integer.
    :returns: the plaintext, reduced into ``[0, n)``.
    """
    xp: int = x % p
    xq: int = x % q
    # Three-argument pow reduces at every step; pow(xp, dp) % p would first
    # materialise xp**dp, which is infeasible for a key-sized exponent.
    yp: int = pow(xp, dp, p)
    yq: int = pow(xq, dq, q)
    # Parenthesise the sum: '%' binds tighter than '+', so without it only
    # the second term was reduced and the result could exceed n.
    return (q * cp * yp + p * cq * yq) % n


def decrypt_naive(x):
    """Decrypt the integer ciphertext ``x`` as ``x**d mod n`` without CRT.

    This is the baseline the CRT variant is benchmarked against.

    :param x: ciphertext as an integer.
    :returns: the plaintext, an integer in ``[0, n)``.
    """
    # Modular pow keeps every intermediate below n; pow(x, d) % n would have
    # to build x**d first, which never finishes for a key-sized d.
    result: int = pow(x, d, n)
    return result


def benchmark_chunky(x):
    """Benchmark decryption of every byte of ``x`` and print the totals.

    Each byte is encrypted and decrypted on its own via ``benchmark``; the
    summed naive and CRT timings, their difference and the relative saving
    are printed afterwards.

    :param x: a bytes-like object, or a non-negative integer, which is first
        split into its big-endian bytes.
    """
    if isinstance(x, int):
        # bytearray(int) would allocate a zero-filled buffer of length x
        # (OverflowError for a 2048-bit x) instead of splitting x into bytes.
        x = number.long_to_bytes(x)
    results = [benchmark(i, c) for i, c in enumerate(bytearray(x))]
    elapsed_naive = math.fsum(r[0] for r in results)
    elapsed_chinese = math.fsum(r[1] for r in results)
    # Empty input (or a clock too coarse to register time) gives 0 here.
    ratio = 1.0 - (elapsed_chinese / elapsed_naive) if elapsed_naive else float('nan')
    print('----------------------------------------------------\n',
          elapsed_naive, elapsed_chinese, elapsed_naive - elapsed_chinese, ratio)


def benchmark(i, x):
    """Encrypt ``x``, time both decryption strategies and print one result row.

    :param i: zero-based chunk index; printed 1-based.
    :param x: integer plaintext, encrypted with the module-level ``key``.
    :returns: ``(elapsed_naive, elapsed_chinese)`` in seconds.
    :raises AssertionError: if either decryption does not reproduce ``x``.
    """
    # NOTE(review): indexing [0] assumes key.encrypt returns a 1-tuple
    # (legacy PyCrypto API) — confirm against the installed library.
    encrypted = key.encrypt(x, None)[0]

    # perf_counter is the high-resolution clock intended for benchmarks;
    # time() can be too coarse and report 0 elapsed time.
    start_chinese = perf_counter()
    result_chinese = decrypt_chinese_rest(encrypted)
    elapsed_chinese = perf_counter() - start_chinese

    start_naive = perf_counter()
    result_naive = decrypt_naive(encrypted)
    elapsed_naive = perf_counter() - start_naive

    if result_chinese != x or result_naive != x:
        raise AssertionError('decryption unsuccessful, naive %i, chinese: %i' % (result_naive ,result_chinese))

    # Guard against ZeroDivisionError when the naive timing rounds to 0.
    ratio = 1.0 - (elapsed_chinese / elapsed_naive) if elapsed_naive else float('nan')
    print(i + 1, elapsed_naive, elapsed_chinese, elapsed_naive - elapsed_chinese, ratio)
    return elapsed_naive, elapsed_chinese

# Benchmark: 50 rounds, each timing naive vs. CRT decryption of every byte of
# a random 2048-bit integer.
print('\t'.join(['#', 'naive duration', 'chinese duration', 'difference', '%']))
for _ in range(50):
    # Pass bytes, not the int itself: bytearray(int) would try to allocate an
    # int-sized zero buffer instead of splitting the number into bytes.
    benchmark_chunky(number.long_to_bytes(number.getRandomInteger(2048)))