#! /usr/local/bin/env python3
# -*- coding: utf-8 -*-
import sympy
import random
import numpy as np
import multiprocessing as mp
import time
class RSA:
    """Toy RSA key pair that encrypts and decrypts text one UTF-8 byte at a time.

    An instance is either a *receiver* (holds the full key material and can
    decrypt) or a *sender* (holds at most the public key ``(n, e)`` and can
    only encrypt).

    NOTE: this is an educational implementation. Primes are chosen with the
    non-cryptographic ``random`` module, keys are tiny and each byte is
    encrypted independently, so it must not be used to protect real data.
    """

    def __init__(self, start:int=100, range:int=100, *, params:dict=None, sender:bool=False, n:int=None, e:int=None):
        """Create a sender, restore a receiver from ``params``, or generate a key.

        Args:
            start: Smallest prime index (1-based, as used by ``sympy.prime``)
                from which ``p`` and ``q`` are drawn when generating a key.
            range: Width of the prime-index window. The name shadows the
                builtin but is kept because callers may pass it by keyword.
            params: Existing key material with the keys
                ``p, q, n, e, l, d, y``; when given, no key is generated.
            sender: When true, build an encrypt-only instance.
            n: Public modulus for a sender (optional).
            e: Public exponent for a sender (optional).
        """
        # Decrypted values keyed by ciphertext integer; filled by decode().
        self.chars = {}
        if sender:
            # Sender: only the public part of the key may be known.
            self.read_params(
                {"p": None, "q": None, "n": n, "e": e, "l": None, "d": None, "y": None}
            )
        elif params is not None:
            # Receiver restoring a previously generated key.
            self.read_params(params)
        else:
            # Receiver generating a fresh key.
            self._generate_keys(start, range)
        self.set_params()

    def _generate_keys(self, start, span):
        """Pick two distinct primes by index and derive n, e, l, d and y."""
        p_index = random.randint(start, start + span)
        q_index = random.randint(start, start + span - 1)
        # Guarantee p != q without drawing again.
        if p_index == q_index:
            q_index += 1
        self.p = int(sympy.prime(p_index))
        self.q = int(sympy.prime(q_index))
        # Public modulus n.
        self.n = self.p * self.q
        # Public exponent e: the smallest prime that does not divide
        # (p-1)(q-1), hence coprime with it.
        phi_factors = sympy.factorint((self.p - 1) * (self.q - 1))
        e_index = 1
        while sympy.prime(e_index) in phi_factors:
            e_index += 1
        self.e = int(sympy.prime(e_index))
        # l = lcm(p-1, q-1), the modulus of the equation e*d = 1 + y*l.
        self.l = int(sympy.lcm(self.p - 1, self.q - 1))
        # Private exponent d: smallest y >= 1 making (1 + y*l) divisible by e.
        # Pure integer arithmetic keeps d an exact int.
        count = 1
        while (1 + count * self.l) % self.e != 0:
            count += 1
        self.d = (1 + count * self.l) // self.e
        self.y = count

    def decode(self, xs:list, time_count:bool=False, split:int=None, return_time:bool=False):
        """Decrypt a list of ciphertext integers back into text.

        The distinct ciphertext values are split into chunks and each chunk is
        decrypted by ``calc`` in its own process.

        Args:
            xs: Ciphertext integers, one per plaintext byte.
            time_count: Print the elapsed time (unless ``return_time`` is set).
            split: Number of distinct values per worker process. ``None``
                picks a size from the input; a value <= 0 uses one process.
            return_time: Return the elapsed time as a third tuple element.

        Returns:
            ``(byte_values, text)``, or ``(byte_values, text, seconds)`` when
            ``return_time`` is true. If any decrypted value is not a byte, a
            message is printed and ``([], "")`` / ``([], "", 0)`` is returned.

        Raises:
            ValueError: If there is something to decrypt but this instance has
                no private exponent (for example a sender instance).
        """
        timed = time_count or return_time
        if timed:
            start = time.time()
        empty = ([], "", 0) if return_time else ([], "")

        xs_set = set(xs)
        # Number of distinct ciphertext values handed to each process.
        if split is None:
            if not xs_set:
                return empty
            multi_num = 4 + int(np.log10(len(xs_set) ** 2))
        elif split > 0:
            multi_num = split
        else:
            multi_num = len(xs_set) + 1

        # Without this guard the workers would crash and the parent would
        # block forever waiting on their queues.
        if xs_set and self.d is None:
            raise ValueError(
                "decode() requires a private key; this instance only holds a public key"
            )

        # Distribute the distinct values over the worker processes.
        pending = list(xs_set)
        chunks = [pending[i:i + multi_num] for i in range(0, len(pending), multi_num)]
        queues = [mp.Queue() for _ in chunks]
        processes = [
            mp.Process(target=self.calc, args=([chunk], queue))
            for chunk, queue in zip(chunks, queues)
        ]

        for process in processes:
            process.start()
        for process, queue in zip(processes, queues):
            # Read before join(): a child cannot exit while its queue still
            # holds unread data.
            payload = queue.get()
            half = len(payload) // 2
            # First half: ciphertext values; second half: decrypted values.
            self.chars.update(zip(payload[:half], payload[half:]))
            process.join()

        result = []
        for x in xs:
            value = self.chars[x]
            if value >= 1 << 8:
                # Not a single byte: wrong key or corrupted ciphertext.
                print("[!] 復号化した文字列が不正です。")
                return empty
            result.append(value)

        # Reassemble the bytes and decode them as UTF-8.
        re_str = bytes(result).decode("utf8")
        if timed:
            decode_time = round(time.time() - start, 4)
            if return_time:
                return result, re_str, decode_time
            print(f"復号化時間 : {decode_time} 秒")
        return result, re_str

    def encode(self, xs:str, n:int=None, e:int=None, is_str:bool=True, time_count:bool=False):
        """Encrypt a string, one UTF-8 byte per ciphertext integer.

        Args:
            xs: Plaintext string.
            n: Public modulus; defaults to this instance's ``n``.
            e: Public exponent; defaults to this instance's ``e``.
            is_str: Also build the space-separated text form of the result.
            time_count: Print the elapsed time.

        Returns:
            ``(codes, text)`` where ``codes`` is the list of ciphertext
            integers and ``text`` is each code followed by a space (empty when
            ``is_str`` is false).

        Raises:
            TypeError: If no public key is available.
        """
        if time_count:
            start = time.time()
        n = self.n if n is None else n
        e = self.e if e is None else e
        if n is None or e is None:
            raise TypeError("encode() needs a public key (n, e)")
        n, e = int(n), int(e)

        # Encode as UTF-8 and encrypt every distinct byte value only once.
        data = xs.encode("utf8")
        table = {byte: pow(byte, e, n) for byte in set(data)}
        result = [table[byte] for byte in data]

        re_str = "".join(f"{code} " for code in result) if is_str else ""
        if time_count:
            print(f"暗号化時間 : {time.time() - start:.4f} 秒")
        return result, re_str

    def calc(self, xs:list, queue):
        """Decrypt one chunk of ciphertext values and put the result on ``queue``.

        Args:
            xs: One-element list wrapping the chunk of ciphertext integers.
            queue: Object with a ``put`` method. It receives a single list:
                the chunk's ciphertext values followed by their decrypted
                values in the same order.
        """
        values = xs[0]
        d, n = int(self.d), int(self.n)
        result = list(values)
        for x in values:
            # Modular exponentiation never builds the huge intermediate x**d.
            buf = pow(int(x), d, n)
            result.append(buf)
            print(f"** {x} -> {buf} **")
        queue.put(result)

    def get_keys(self):
        """Return the public key as the tuple ``(n, e)``."""
        return self.n, self.e

    def get_params(self):
        """Return the dict of all key parameters built by ``set_params``."""
        return self.params

    def set_params(self):
        """Snapshot the current key attributes into ``self.params``."""
        self.params = {
            "p": self.p,
            "q": self.q,
            "n": self.n,
            "e": self.e,
            "l": self.l,
            "d": self.d,
            "y": self.y,
        }

    def read_params(self, params):
        """Load key attributes from a dict with keys ``p, q, n, e, l, d, y``."""
        self.p = params["p"]
        self.q = params["q"]
        self.n = params["n"]
        self.e = params["e"]
        self.l = params["l"]
        self.d = params["d"]
        self.y = params["y"]

    def __repr__(self):
        # An instance without p is a sender holding only the public key.
        if self.p is None:
            return f"RSA(sender=True, n={self.n}, e={self.e})"
        return f"RSA(params={self.params})"
if __name__ == "__main__":
    # Demo: a receiver publishes its key, a sender encrypts typed text with
    # it, and the receiver decrypts the result.
    #
    # To reuse an existing key instead of generating one:
    # pre_params = {'p': 1039, 'q': 691, 'n': 717949, 'e': 7, 'l': 119370, 'd': 17053, 'y': 1}
    # recv = RSA(params=pre_params)
    receiver = RSA(100, 100)       # receiver: owns the private key
    transmitter = RSA(sender=True)  # sender: is handed the public key below
    print(repr(receiver))
    public_n, public_e = receiver.get_keys()  # fetch the public key
    print(f"公開鍵 : {public_n}, {public_e}")

    # Sender side: read a line and encrypt it with the receiver's public key.
    plain_input = input("入力 > ")
    cipher_codes, cipher_text = transmitter.encode(plain_input, public_n, public_e)
    print(f"暗号文 : {cipher_text}", end="\n\n")

    # Receiver side: decrypt the ciphertext integers.
    _, recovered_text = receiver.decode(cipher_codes)
    print(f"復号文 : {recovered_text}")