調べ物した結果

現役SEが仕事と直接関係ないことを調べた結果とか感想とか

gpt-2のソースを眺めてみる①~適当リファクタ~

今回のお相手

https://github.com/openai/gpt-2
お借りします。
GPT-2の詳しいことを知りたくてこのページを見に来た人すみません。回れ右です。

前回


だいぶ脱線してしまったけどGPT-2についてあっさーいところは理解できた。

対象の選定。


安定のCC値(サイクロマティック複雑度)を採用する。頑張れradon。
ソースは前回gitから拝借している。
ターミナルを起動して、
radon cc D:\git_refact\gpt-2\src -n C
をキックするだけ。

f:id:couraeg:20191121231202p:plain

優秀すぎやしませんかね。Cランク以下はEncoderクラスだけの模様。
Encoderクラス、bpeメソッドを対象にすることに決まり。
M 55:4 Encoder.bpe - C

bpeめそっど

    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]
        word = tuple(token)
        pairs = get_pairs(word)

        if not pairs:
            return token

        while True:
            bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except:
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word)-1 and word[i+1] == second:
                    new_word.append(first+second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word


100行程度多少がネストふかいかなーという印象。
コンストラクタがちょっと厄介な感じもするし、どうしてWhile Trueしてしまうのか。僕もやるはやる。
なんとなく、python臭がしないというか、どこか別言語専門の人が書いたようなそんな感じもうける。気のせいかもしれない。
bpeは外部から参照されていないようで、つかいかたとしてはprivateな様子。pythonの記述この変がいいかげん(適切な表現ではないと思うけど)
なの気にはなるなぁ。

とりあえず。一番簡単(気を付けないと危ないけど)なやつ


変数名をかえた。privateを示すように。
def _bpe(self, token):
名づけは有能なリファクタですよね。うん。

テストをつくりはじめる。

class Test_Encoder(unittest.TestCase):
    _encoder = encoder.Encoder(None, None)


雑にインスタンスを生成。実行するとエラーで落ちる。
コンストラクタを見てみる。

    def __init__(self, encoder, bpe_merges, errors='replace'):
        self.encoder = encoder
        self.decoder = {v:k for k,v in self.encoder.items()}
        self.errors = errors # how to handle errors in decoding
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v:k for k, v in self.byte_encoder.items()}
        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
        self.cache = {}

        # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
        self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")

ふむ。enocoderは何らかのルールが必要な様子。けどコメントもなくて全然わからん。
とおもって、同ファイルを眺めていると

def get_encoder(model_name, models_dir):
    with open(os.path.join(models_dir, model_name, 'encoder.json'), 'r') as f:
        encoder = json.load(f)
    with open(os.path.join(models_dir, model_name, 'vocab.bpe'), 'r', encoding="utf-8") as f:
        bpe_data = f.read()
    bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]
    return Encoder(
        encoder=encoder,
        bpe_merges=bpe_merges,
    )

ファクトリらしきものをはっけん。ディレクトリ指定して
jsonファイルを読み込むようにみえる。このbpeというのはDockerに記載のあったやつか?
Docker内で使われているdownload_model.pyをみても
for filename in ['checkpoint','encoder.json','hparams.json','model.ckpt.data-00000-of-00001', 'model.ckpt.index', 'model.ckpt.meta', 'vocab.bpe']:
とかある。たぶん推測どおりだろう。まいったな。 テストをするのに特別にファイルとか、環境が必要なのはなんだかいやだなぁ。

異様にくるしいが、Jsonを・・・


ということで、Json形式からitems()を呼び出せるオブジェクトを生成しないといけない。
loadがDict型をつくるようなので、すごく雑に文字列を生成してみる。

    def main(self):
        _encoder = json.loads(self.test_jsonEncoder())
        _obj = encoder.Encoder(_encoder, None)

    def test_jsonEncoder(self):
        jsonString = "{\n"
        jsonString += "             \"item1\": \"value1\","
        jsonString += "             \"item2\": \"value2\""
        jsonString += "}"
        return jsonString

エラーがかわった。とりあえずこれで行けるらしい。
他のオブジェクトもとりあえずコンストラクタが通るように生成。
テスト用に整形しなおして

import unittest
import encoder
import json

class Test_Encoder(unittest.TestCase):
    def test_constractor(self):
        ins = self.testutil_create_instance()

    def testutil_create_instance(self):
        _encoder = json.loads(self.testutil_jsonEncoder())
        _bpe = {}
        return encoder.Encoder(_encoder, _bpe)

    def testutil_jsonEncoder(self):
        jsonString = "{\n"
        jsonString += "             \"item1\": \"value1\","
        jsonString += "             \"item2\": \"value2\""
        jsonString += "}"
        return jsonString

if __name__ == "__main__":
    unittest.main()

そんなに悪くないんじゃないかな。これでひとまず実行はできる。

最後にradonでCC値をはかって

d:\git_refact\gpt-2\src>radon cc D:\git_refact\gpt-2\src -n C
D:\git_refact\gpt-2\src\encoder.py
M 54:4 Encoder._bpe - C

なんで劣化しとんのや・・・とりあえず今日はここまでようやく下準備が終わった感じか。
下々準備かな。