top
おさだのホームページ

RSA暗号を作ってみる
サマーウォーズのあの暗号

1. RSA暗号とは

RSA暗号は公開鍵方式といい、受信者が秘密鍵と公開鍵、送信者が公開鍵のみを用いて暗号化・復号化を行う。 秘密鍵は受信者の元から離れることがないので安全であると言われている。


暗号に関わる数字は下記の6個である。
 
p  :  素数
q  :  素数(ただし、p ≠ q)
n  :  p * q
e  :  p-1 と q-1 に対して互いに素な数(ただし e > 2)
L  :  p-1 と q-1 の最小公倍数(単に p-1 * q-1 でも可だが、桁数が大きくなるので最小公倍数を推奨)
d  :  de - yL = 1 の整数解 (d, y) のどれか一つ

上記の数字のうち、n, e が公開鍵d が秘密鍵となる。


平文と暗号文の関係は、
 
平文 ^ e (mod n) = 暗号文
暗号文 ^ d (mod n) = 平文

であるので、①が送信側、②が受信側となる。


例として p = 5, q = 11 の時の公開鍵と秘密鍵を計算すると、
n = 5*11 = 55
e = 3
p-1 = 4 = 2^2, q-1 = 10 = 2 * 5 より、最小公倍数 L = 2^2 * 5 = 20
3d - 20y = 1 より、d = 7, y = 1

よって公開鍵は 55 と 3 で、秘密鍵は 7 になる。


2. まず実直な実装

1節で紹介した仕様通りにPythonを書いた。また、このプログラムではnumpyをあまり利用していない。その理由は6節で後述する。

copy
make_rsa.py
展開
折り畳む
#! /usr/local/bin/env python3
#! encode : -*- utf-8 -*-

import sympy
import random


class RSA:

    def __init__(self, start:int=100, range:int=100, *, sender:bool=False, n:int=None, e:int=None):

        # 送信者用の設定
        if(sender):
            self.n = n
            self.e = e

        # 受信者用の設定
        else:
            # 新たな鍵の生成
            p_buf = random.randint(start, start+range)
            q_buf = random.randint(start, start+range-1)

            if(p_buf == q_buf): q_buf += 1

            # 素数 p
            self.p = sympy.prime(p_buf)
            # 素数 q
            self.q = sympy.prime(q_buf)
            # 公開鍵 n
            self.n = self.p * self.q

            # 公開鍵 e
            facts = sympy.factorint((self.p-1)*(self.q-1))
            buf_e = 1
            while sympy.prime(buf_e) in facts.keys():
                buf_e += 1
            self.e = sympy.prime(buf_e)

            # 不定方程式を構成する変数 L
            self.l = sympy.lcm(self.p-1, self.q-1)

            # 秘密鍵 d
            count = 1
            while ((1+(count*self.l))/self.e) % 1 != 0:
                count += 1
            self.d = (1+(count*self.l))/self.e
            self.y = count


    ## 復号化
    def decode(self, xs:list):
        
        result = []
        for x in xs:
            result.append((x**self.d)%self.n) # 暗号文^d (mod n) = 平文
            print(f"** {x} -> {result[len(result)-1]} **")
        
        # utf8 でデコード
        re_str = bytes(result).decode("utf8")

        return result, re_str
    

    ## 暗号化
    def encode(self, xs:list, n:int=None, e:int=None):
        
        n = self.n if n == None else n
        e = self.e if e == None else e
        result = []
        chars = dict()
        re_str = ""

        # utf8 でエンコード
        xs = list(xs.encode("utf8"))

        for x in xs:
            result.append((x**e)%n) # 平文^e (mod n) = 暗号文

        for r in result:
            re_str += f"{r} "

        return result, re_str


    ## 公開鍵の取得
    def get_keys(self):
        return self.n, self.e


RSA(n, m) で n 〜 n+m 番目までの素数を使った公開鍵と秘密鍵を生成する。また、RSA(sender=True) とすることで鍵を生成することなく暗号化の機能のみを利用できる。

公開鍵の取得:RSA.get_keys()、戻り値:n, e
暗号化:RSA.encode(平文, n, e)、戻り値:暗号文, 暗号文を読みやすくした文字列
復号化:RSA.decode(暗号文)、戻り値:復号文, 復号文を読みやすくした文字列


実行結果
>> recv = RSA(100, 100) # 受信者
... send = RSA(sender=True) # 送信者
...
... n, e = recv.get_keys() # 公開鍵の取得
... print(f"公開鍵 : {n}, {e}")
...
... ## 送信側
... xs = input("入力  > ")
... send_code, send_str = send.encode(xs, n, e) # 暗号化
... print(f"暗号文 : {send_str}\n")
...
... ## 受信側
... recv_code, recv_str = recv.decode(send_code) # 復号化
... print(f"復号文 : {recv_str}")


公開鍵 : 717949, 7
入力  > hello
暗号文 : 585259 475730 472245 472245 188007

** 585259 -> 104 **
** 475730 -> 101 **
** 472245 -> 108 **
** 472245 -> 108 **
** 188007 -> 111 **
復号文 : hello

入力の文字列 "hello" が復元されている。
これでRSA暗号は完成だが、大きい桁の乗剰算を行うために実行時間が長いことが課題である。次の節からはこの計算の高速化について考える。


3. 既解読暗号の再利用

先程の "hello" の暗号化・復号化において " l " は重複しているため、2回同じ計算を行っていることになる。 これを1種類の文字につき必ず1回の計算で済むように修正するには、先に登場する全ての文字列を重複なしで計算しそれを辞書に登録し、その辞書をもとに暗号化・復号化を行う方法が挙げられる。

この考えをもとに所々修正した。(実行方法は変わらない)

copy
クラス定義部分
class RSA:

    def __init__( ...
    ...

        # 受信者用の設定
        else:
            self.chars = dict() # 復号化用の辞書
    ...
copy
decode()の中身
## 復号化
def decode(self, xs:list):

    result = []
    # 辞書の作成
    for x in set(xs):
        self.chars[x] = (x**self.d)%self.n # 暗号文^d (mod n) = 平文
        print(f"** {x} -> {self.chars[x]} **")

    # 辞書をもとに復号化
    for x in xs:
        result.append(self.chars[x])

    # utf8 でデコード
    re_str = bytes(result).decode("utf8")

    return result, re_str
copy
encode()の部分
## 暗号化
def encode(self, xs:list, n:int=None, e:int=None):
    
    if(time_count):
        start = time.time()
    
    n = self.n if n == None else n
    e = self.e if e == None else e
    result = []
    chars = dict() # 暗号化用の辞書
    re_str = ""

    # utf8 でエンコード
    xs = list(xs.encode("utf8"))
    xs_set = set(xs)

    # 辞書の作成
    for x_s in xs_set:
        chars[x_s] = (x_s**e)%n # 平文^e (mod n) = 暗号文

    # 辞書をもとに暗号化
    for x in xs:
        result.append(chars[x])

    for r in result:
        re_str += f"{r} "

    return result, re_str


これで無駄な計算を省くことができた。


4. 並行処理による計算の高速化

3節にてある程度は計算が速くなったが、結局は辞書の作成に一番時間が掛かるので根本的には解決できていない。 一方、この辞書のキーは重複をしていないので何個もの辞書に分割したり、逆にくっつけたりしても壊れることはない。 そこで、この辞書を複数個に分割してそれぞれを並行処理にて計算し、再び一つに統合することで高速化を図ることができる。

つまり、受け取った文字列をいくつかに分割し、それをマルチプロセスにて計算を行う。


話は変わるが、かなり大きい素数で実験を行なったところ暗号化にはそこまでの計算時間を要さず、逆に復号化には小さな素数であってもかなりの計算時間がかかることがわかった。 なのでマルチプロセスによる計算は復号化について行う。


この考えをもとにプログラムを大幅に修正した。(実行方法は変わらない)

copy
import部分
import sympy
import random
import numpy as np
import multiprocessing as mp
copy
decode()の中身
展開
折り畳む
## 復号化
def decode(self, xs:list, split:int=None):

    result = []
    xs_set = set(xs)
    xs_set_buf = xs_set.copy()

    # 一つのスレッドに振り分ける暗号の数
    if(split == None):
        xs_length = len(xs_set)
        if(xs_length == 0):
            return [], ""
        else:
            multi_num = 4 + int(np.log10(len(xs_set) ** 2))
    elif(split > 0):
        multi_num = split
    else:
        multi_num = len(xs_set) + 1
        
    xs_len = len(xs_set)
    xs_div, xs_mod = xs_len//multi_num, xs_len%multi_num

    processes = []
    queues = []

    # 暗号をスレッドに振り分け
    for d in range(xs_div):
        buf = []
        for _ in range(multi_num):
            buf.append(xs_set_buf.pop())
        
        queues.append(mp.Queue())
        processes.append(mp.Process(target=self.calc, args=([buf], queues[d])))

    if(xs_mod != 0):
        buf = []
        for _ in range(xs_mod):
            buf.append(xs_set_buf.pop())

        queues.append(mp.Queue())
        processes.append(mp.Process(target=self.calc, args=([buf], queues[xs_div])))

    # 計算
    for p in range(len(processes)):
        processes[p].start()

    for p in range(len(processes)):
        buf = queues[p].get()
        keys = buf[:len(buf)//2]
        values = buf[len(buf)//2:]

        for k in range(len(keys)):
            self.chars[keys[k]] = values[k]
        processes[p].join()

    for x in xs:
        result.append(self.chars[x])
    
    # utf8 でデコード
    re_str = bytes(result).decode("utf8")

    return result, re_str
copy
RSAクラスに追加した部分
## 剰余の計算
def calc(self, xs:list, queue):
   result = xs[0].copy()
   for x in xs[0]:
        buf = (x**self.d)%self.n
        result.append(buf)
        print(f"** {x} -> {buf} **")

    queue.put(result)


実行した結果は好調であり、前節まで10秒以上掛かっていた計算も1秒程で済むようになった。おそらく自作の関数をmultiprocessing.Process()の中で実行したことも高速化に寄与しているだろう。


5. ちょっと修正して完成

既存の鍵の代入機能、各計算時間の出力、エラーへの対応、repr()への出力を主に修正し、実行例を追記した。(実行方法は変わらない)

copy
make_rsa.py 完成版
展開
折り畳む
#! /usr/local/bin/env python3
#! encode : -*- utf-8 -*-

import sympy
import random
import numpy as np
import multiprocessing as mp
import time


class RSA:

    def __init__(self, start:int=100, range:int=100, *, params:dict=None, sender:bool=False, n:int=None, e:int=None):

        # 送信者用の設定
        if(sender):
            params_buf = {"p":None, "q":None, "n":n, "e":e, "l":None, "d":None, "y":None}
            self.read_params(params_buf)
            self.set_params()

        # 受信者用の設定
        else:
            self.chars = dict()

            # 既存の鍵を設定
            if(params != None):
                self.read_params(params)
                self.set_params()

            # 新たな鍵の生成
            else:
                p_buf = random.randint(start, start+range)
                q_buf = random.randint(start, start+range-1)

                if(p_buf == q_buf): q_buf += 1

                # 素数 p
                self.p = sympy.prime(p_buf)
                # 素数 q
                self.q = sympy.prime(q_buf)
                # 公開鍵 n
                self.n = self.p * self.q

                # 公開鍵 e
                facts = sympy.factorint((self.p-1)*(self.q-1))
                buf_e = 1
                while sympy.prime(buf_e) in facts.keys():
                    buf_e += 1
                self.e = sympy.prime(buf_e)

                # 不定方程式を構成する変数 l
                self.l = sympy.lcm(self.p-1, self.q-1)

                # 秘密鍵 d
                count = 1
                while ((1+(count*self.l))/self.e) % 1 != 0:
                    count += 1
                self.d = (1+(count*self.l))/self.e
                self.y = count
                
                self.set_params()


    ## 復号化
    def decode(self, xs:list, time_count:bool=False, split:int=None, return_time:bool=False):
        
        if(time_count):
            start = time.time()

        result = []
        xs_set = set(xs)
        xs_set_buf = xs_set.copy()

        # 一つのスレッドに振り分ける暗号の数
        if(split == None):
            xs_length = len(xs_set)
            if(xs_length == 0):
                if(return_time):
                    return [], "", 0
                else:
                    return [], ""
            else:
                multi_num = 4 + int(np.log10(len(xs_set) ** 2))
        elif(split > 0):
            multi_num = split
        else:
            multi_num = len(xs_set) + 1
            
        xs_len = len(xs_set)
        xs_div, xs_mod = xs_len//multi_num, xs_len%multi_num

        processes = []
        queues = []

        # 暗号をスレッドに振り分け
        for d in range(xs_div):
            buf = []
            for _ in range(multi_num):
                buf.append(xs_set_buf.pop())
            
            queues.append(mp.Queue())
            processes.append(mp.Process(target=self.calc, args=([buf], queues[d])))

        if(xs_mod != 0):
            buf = []
            for _ in range(xs_mod):
                buf.append(xs_set_buf.pop())

            queues.append(mp.Queue())
            processes.append(mp.Process(target=self.calc, args=([buf], queues[xs_div])))

        # 計算
        for p in range(len(processes)):
            processes[p].start()

        for p in range(len(processes)):
            buf = queues[p].get()
            keys = buf[:len(buf)//2]
            values = buf[len(buf)//2:]

            for k in range(len(keys)):
                self.chars[keys[k]] = values[k]
            processes[p].join()

        for x in xs:
            if(self.chars[x] < 1<<8):
                result.append(self.chars[x])
            else:
                print("[!] 復号化した文字列が不正です。")
                if(return_time):
                    return [], "", 0
                else:
                    return [], ""
        
        # utf8 でデコード
        re_str = bytes(result).decode("utf8")

        if(time_count):
            decode_time = round(time.time() - start, 4)
            if(return_time):
                return result, re_str, decode_time
            else:
                print(f"復号化時間 : {decode_time} 秒")

        return result, re_str
    

    ## 暗号化
    def encode(self, xs:list, n:int=None, e:int=None, is_str:bool=True, time_count:bool=False):
        
        if(time_count):
            start = time.time()
        
        n = self.n if n == None else n
        e = self.e if e == None else e
        result = []
        chars = dict()
        re_str = ""

        # utf8 でエンコード
        xs = list(xs.encode("utf8"))
        xs_set = set(xs)

        for x_s in xs_set:
            chars[x_s] = (x_s**e)%n
        for x in xs:
            result.append(chars[x])

        if(is_str):
            for r in result:
                re_str += f"{r} "

        if(time_count):
            print(f"暗号化時間 : {time.time() - start:.4f} 秒")

        return result, re_str


    ## 剰余の計算
    def calc(self, xs:list, queue):
        result = xs[0].copy()
        for x in xs[0]:
            # buf = np.mod(np.power(x, self.d), self.n) # 整数の精度悪
            buf = (x**self.d)%self.n
            result.append(buf)
            print(f"** {x} -> {buf} **")

        queue.put(result)


    ## 公開鍵の取得
    def get_keys(self):
        return self.n, self.e


    ## 全ての変数を取得
    def get_params(self):
        return self.params


    ## 変数の辞書を作成
    def set_params(self):
        self.params = {"p":self.p, "q":self.q, "n":self.n, "e":self.e, "l":self.l, "d":self.d, "y":self.y}


    ## 既存の変数を設定
    def read_params(self, params):
        self.p = params["p"]
        self.q = params["q"]
        self.n = params["n"]
        self.e = params["e"]
        self.l = params["l"]
        self.d = params["d"]
        self.y = params["y"]
    

    def __repr__(self):
        if(self.p == None):
            return f"RSA(sender=True, n={self.n}, e={self.e})"
        else:
            return f"RSA(params={self.params})"




if(__name__ == "__main__"):
    
    # 既存の鍵を代入するならこっち
    # pre_params = {'p': 1039, 'q': 691, 'n': 717949, 'e': 7, 'l': 119370, 'd': 17053, 'y': 1}
    # recv = RSA(params=pre_params)

    recv = RSA(100, 100) # 受信者
    send = RSA(sender=True) # 送信者

    print(repr(recv))

    n, e = recv.get_keys() # 公開鍵の取得
    print(f"公開鍵 : {n}, {e}")

    ## 送信側
    xs = input("入力  > ")
    send_code, send_str = send.encode(xs, n, e) # 暗号化
    print(f"暗号文 : {send_str}", end="\n\n")

    ## 受信側
    recv_code, recv_str = recv.decode(send_code) # 復号化
    print(f"復号文 : {recv_str}")


decode()関数に return_time=True を渡すことで第3戻り値が計算時間になるので、それを利用して他のプログラムから複数回実行し理想的な引数やプロセスの個数を探索することが可能である。


6. numpyのpower()とmod()はあまり使えない話

   私は当初、計算の高速化=numpyだと考えてそれを使用してコーディングを行なった。しかしエラーの連続であった。 詳しく見てみるとnumpy.mod()とnumpy.power()関数の実行結果がおかしいことに気づき、調べてみるとnumpyで扱うことのできる最大の符号なし整数は2^64 - 1であり、それ以上は精度が削がれるとのことであった。

   今回使用する数値は 暗号文^d が 10万^10万 程の桁数になっていることからわかる通り、2^64 - 1など軽く超えているのでnumpyでは計算できなかったのである。 numpyのバージョンによっては2^64 - 1を超える数値を入力するとエラーが出るようだが、私のnumpyではそのまま実行できてしまったので発見が遅れた。 整数の桁を分けて実行する方法があるみたいだが、面倒なので今回は標準搭載の演算子の「*」と「%」でそのまま記述した。

   また、numpyの代わりにmathモジュールを試してみたところ、計算自体は正確に実行してくれたが計算時間の短縮にはあまりならなかった。


7. サマーウォーズの主人公は何を計算していたのか

   映画『サマーウォーズ』にて主人公が序盤と最終盤に数字の羅列から英文を書き出すシーンがある。 あれがまさにRSA暗号であるが、彼が行っていたのは片方の公開鍵のみでの暗号文の復号化である。 要は、

公開鍵 n と暗号文 y が判明している時、x^e (mod n) = y、x^ed (mod n) = x となる x, e, d の組み合わせの内、x の値が文章になりそうなものを探せ

という問題を解いている。


   まず、上記の連立不定方程式の問題を解く(素因数分解をする)ためにはかなりの時間を要し、今のスパコンをもってしても何千年やら何万年やら掛かるとも言われている。 素因数分解について効果的なアルゴリズムは現在見つかっていないので、一般的にはある数値を素数で割り、それが割り切れれば素因数、割り切れなければ次の素数で割るアルゴリズムで執り行われる。

   RSA暗号の鍵はもともと素数で作られているおかげで共通の因数が存在せず、かつ素因数が巨大であるので素数を具に走査するものではとんでもなく時間が掛かってしまう。 また、素数探索に関してはエラトステネスの篩が有効であるが、秘密鍵 e に対して時間計算量が O(e*log(log(e))) であるので巨大素数に対して実用性があるとは言えない。

   サマーウォーズの劇中では「Shorの因数分解アルゴリズム」が書かれた参考書を主人公が読んでいるシーンが登場する。 このアルゴリズムは目的の整数体の周期性から因数を割り出すもので、その時間計算量が多項式時間(入力サイズ x, 実数 a に対して時間計算量が ax であるもの)であることが知られている。 しかし、素数の積が含まれる問題に対しては超多項式時間(入力サイズ x に対して時間計算量が x^2 や x! などであるもの)に跳ね上がることも知られている。

   今回の問題は思いっきり素数の積の累乗が登場するので超多項式時間となり現実的ではないが、このアルゴリズムは量子ビットと相性が良いとされており、 もしそれを計算可能な量子コンピュータが実現すればRSA暗号を3分余りで解くことができるとも言われている。 主人公の脳内にQPUが存在するなのなら納得である。