gpt-2のソースを眺めてみる①~適当リファクタ~
今回のお相手
https://github.com/openai/gpt-2
お借りします。
GPT-2の詳しいことを知りたくてこのページを見に来た人すみません。回れ右です。
前回
だいぶ脱線してしまったけどGPT-2についてあっさーいところは理解できた。
対象の選定。
安定のCC値(サイクロマティック複雑度)を採用する。頑張れradon。
ソースは前回gitから拝借している。
ターミナルを起動して、
radon cc D:\git_refact\gpt-2\src -n C
をキックするだけ。
優秀すぎやしませんかね。Cランク以下はEncoderクラスだけの模様。
Encoderクラス、bpeメソッドを対象にすることに決まり。
M 55:4 Encoder.bpe - C
bpeめそっど
def bpe(self, token): if token in self.cache: return self.cache[token] word = tuple(token) pairs = get_pairs(word) if not pairs: return token while True: bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) if bigram not in self.bpe_ranks: break first, second = bigram new_word = [] i = 0 while i < len(word): try: j = word.index(first, i) new_word.extend(word[i:j]) i = j except: new_word.extend(word[i:]) break if word[i] == first and i < len(word)-1 and word[i+1] == second: new_word.append(first+second) i += 2 else: new_word.append(word[i]) i += 1 new_word = tuple(new_word) word = new_word if len(word) == 1: break else: pairs = get_pairs(word) word = ' '.join(word) self.cache[token] = word return word
100行程度多少がネストふかいかなーという印象。
コンストラクタがちょっと厄介な感じもするし、どうしてWhile Trueしてしまうのか。僕もやるはやる。
なんとなく、python臭がしないというか、どこか別言語専門の人が書いたようなそんな感じもうける。気のせいかもしれない。
bpeは外部から参照されていないようで、つかいかたとしてはprivateな様子。pythonの記述この変がいいかげん(適切な表現ではないと思うけど)
なの気にはなるなぁ。
とりあえず。一番簡単(気を付けないと危ないけど)なやつ
変数名をかえた。privateを示すように。
def _bpe(self, token):
名づけは有能なリファクタですよね。うん。
テストをつくりはじめる。
class Test_Encoder(unittest.TestCase): _encoder = encoder.Encoder(None, None)
雑にインスタンスを生成。実行するとエラーで落ちる。
コンストラクタを見てみる。
def __init__(self, encoder, bpe_merges, errors='replace'): self.encoder = encoder self.decoder = {v:k for k,v in self.encoder.items()} self.errors = errors # how to handle errors in decoding self.byte_encoder = bytes_to_unicode() self.byte_decoder = {v:k for k, v in self.byte_encoder.items()} self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) self.cache = {} # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
ふむ。enocoderは何らかのルールが必要な様子。けどコメントもなくて全然わからん。
とおもって、同ファイルを眺めていると
def get_encoder(model_name, models_dir): with open(os.path.join(models_dir, model_name, 'encoder.json'), 'r') as f: encoder = json.load(f) with open(os.path.join(models_dir, model_name, 'vocab.bpe'), 'r', encoding="utf-8") as f: bpe_data = f.read() bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]] return Encoder( encoder=encoder, bpe_merges=bpe_merges, )
ファクトリらしきものをはっけん。ディレクトリ指定して
jsonファイルを読み込むようにみえる。このbpeというのはDockerに記載のあったやつか?
Docker内で使われているdownload_model.pyをみても
for filename in ['checkpoint','encoder.json','hparams.json','model.ckpt.data-00000-of-00001', 'model.ckpt.index', 'model.ckpt.meta', 'vocab.bpe']:
とかある。たぶん推測どおりだろう。まいったな。 テストをするのに特別にファイルとか、環境が必要なのはなんだかいやだなぁ。
異様にくるしいが、Jsonを・・・
ということで、Json形式からitems()を呼び出せるオブジェクトを生成しないといけない。
loadがDict型をつくるようなので、すごく雑に文字列を生成してみる。
def main(self): _encoder = json.loads(self.test_jsonEncoder()) _obj = encoder.Encoder(_encoder, None) def test_jsonEncoder(self): jsonString = "{\n" jsonString += " \"item1\": \"value1\"," jsonString += " \"item2\": \"value2\"" jsonString += "}" return jsonString
エラーがかわった。とりあえずこれで行けるらしい。
他のオブジェクトもとりあえずコンストラクタが通るように生成。
テスト用に整形しなおして
import unittest import encoder import json class Test_Encoder(unittest.TestCase): def test_constractor(self): ins = self.testutil_create_instance() def testutil_create_instance(self): _encoder = json.loads(self.testutil_jsonEncoder()) _bpe = {} return encoder.Encoder(_encoder, _bpe) def testutil_jsonEncoder(self): jsonString = "{\n" jsonString += " \"item1\": \"value1\"," jsonString += " \"item2\": \"value2\"" jsonString += "}" return jsonString if __name__ == "__main__": unittest.main()
そんなに悪くないんじゃないかな。これでひとまず実行はできる。
最後にradonでCC値をはかって
d:\git_refact\gpt-2\src>radon cc D:\git_refact\gpt-2\src -n C
D:\git_refact\gpt-2\src\encoder.py
M 54:4 Encoder._bpe - C
なんで劣化しとんのや・・・とりあえず今日はここまでようやく下準備が終わった感じか。
下々準備かな。