作者丨Haoqiang Fan@知乎(已授權)
來源丨https://zhuanlan.zhihu.com/p/686616390
編輯丨極市平台
相信很多人和我一樣都是從「古典深度學習」時代一路走來的,面對當今「LLM才是AI「的時代,有著很多的不適應癥狀。看著那麽多的新論文裏的演算法,想從頭開始擼一遍發現要學習一大堆有的沒的的東西,然後 paper 裏提到的實驗條件還嚇人的高,似乎這個喧囂的新世界和自己有點遙不可及。
那麽,有沒有什麽辦法能在「一無所有」的狀態下做點啥呢?
首先,進入新時代了,要更新好自己的思想,放心大膽的當「調包俠」。而現在的確有很多很科學的包!
其中以 hugging face 的 transformers 為集大成者。實際上,從自娛自樂的 toy example,到一系列還挺有影響力的專案比如 Vicuna,LLAVA,翻開程式碼庫,都能看到那行金光閃閃的
import transformers
甚至,翻遍整個repo,都找不到「網絡結構」寫在哪裏,只有一行
from transformers import LlamaForCausalLM
在等著你。
而在 2024 年了,這些包的安裝也沒那麽「陰間」,直接 pip install 回來的 transformers,accelerate,就是親測能用的(當然,假設你已經把 pytorch、CUDA 的安裝和修bug搞定了)。反正我在玩的時候真就是裝了就能用,沒啥玄學。
對於模型訓練 ,官方文件並沒有一個「最小集」的樣本,不過對著文件琢磨一下還是很容易寫出來的:
嗯,是的,一共就16行,配好兩個物件,然後 Trainer.train() 就成了…… 我第一次用的時候沒配 save_steps 導致跑完了不知道模型存哪了,查了一下文件才搞明白。
其中,dataset物件 就是 torch.utils.data.Dataset,要實作 __ len __ 和 __ getitem __ 的介面,這個自己搞搞就好了。
而要使用一個訓練好的模型,直接
import transformersmodel = transformers.GPT2LMHeadModel.from_pretrained(ifname)
都不用自己去手工維護模型的超參數列,這個庫在 checkpoint 資料夾裏已經自己按約定存好了。
然後肯定有人會提出質疑了,這麽直接搞真的嚴肅麽?
不過,讓我們回想一下為啥 CV 類的模型的訓練程式碼都那麽「復雜」,然後就能發現現在的這種搞法的「科學」之處了:
transformers 裏的 Trainer 整合了一些很科學的預設行為,例如定期存 checkpoint 放到實驗名的資料夾下面,把各種曲線資訊同步到 tensorboard、clearml、wandb 等監控軟件,啥參數都不傳也是可以接受的選擇。
而如果你就是要調參,在 TrainingArguments、Trainer、GPT2Config 等地方一共有 138 個可以傳的參數,以及大量透過調方法來填的參數,能滿足不少的需要。比如,可以傳一個 fp16 = True 來「一鍵」加速訓練,而它背後是 apex.amp 這種庫在支撐。
所以來說,把「靠譜」的庫「整合」在一起,暴露出一組帶有合理預設參數的介面,這個方法論在 2024 年來看的確還是可行的。當層層呼叫的這些庫不「炸」的時候,使用體驗相當的絲滑。
然而,如果「炸」了,或者想搞一些比較深度的 hack (比如客製某個算子在訓練的時候觸發一個神奇的行為),就會發現,在一行看似無辜的網絡定義之上,還摞著層層疊疊的 加速庫、最佳化庫、分布式庫、混合精度庫等等,想去「一層層debug」就成了一件恐怖的事情。
好不容易(?)學會了序列建模這個技能(又稱:transformers 庫的安裝與使用),不整點啥活似乎沒意思。
於是,我用我的筆記電腦(帶3060顯卡)訓練了一個 GPT2-small 級別的模型來做 C++ 程式碼的補全。
我從 github 上找了一些經典的 c/c++ 的 repo(如,linux kernel,gcc,cpython,等),收集了 1G 多的 .zip,從裏面找出所有的 c類原始檔,形成了一個 2.2 GB 的訓練集。
不做tokenization,23萬步,長度 512,批大小 4,訓練大概需要不到一天。
然後搞了一個貪心的補全推薦的方法,試了試,好像也不是完全不能用:
訓練出來的語言模型壓縮率大概是 0.7 bits / byte,比 xz -9 還是要強一點的。
考慮到實作出這樣的效果,從搜尋庫名字到實作出來只需要一天,這還是挺驚人的了。
當然,一旦想從這個 baseline 開始提升效果,那麽 language model 裏各種考量的的「洪流」就要來了:
這就是一個很深很深的坑了,而 copilot 之類就是「登峰造極」之後的產物。想往這方面卷,就會一步步走入「去哪融資,去哪買卡,國產半導體何時崛起」的無盡焦慮中。
不過好的一面在於,如果我們不把自己當做一個 「訓 LLM」 的人,而是當一個 「用 transfomers 庫進行序列建模」 的,就會發現事情也沒那麽壞。雖然幾十M的「小」(對 vision 來說其實不小了)模型並不會表現出「大」模型的一些獨特能力,但是也已經可以做一些很有意義的事情了(我還真沒試過用 CNN 硬懟上面的這個 demo 是什麽效果……)。
刨開訓練 transformer 類模型的一些新的技巧,當前這個時代其實最關鍵的問題就變成了: 手頭有啥好的問題,可以表達成一個序列建模嗎?
如果有,或者原來的某些「老」問題可以這麽表示,那不妨用 transformer 們來試一試,也許就有全新的可能。
新的時代,就應該去擁抱新的方法。
最後,附程式碼補全demo的全套程式碼:訓練
train.py
解釋import transformersimport numpy as npimport torchimport osimport randomimport hashlib class CodeDataset(torch.utils.data.Dataset): def __init__(self, name, ctxlen, totalnum): super().__init__() self.name = name fname = './data/codes.txt' self.content_len = os.path.getsize(fname) self.fin = open(fname, 'rb') self.ctxlen = ctxlen self.length = totalnum def __len__(self): return self.length def __getitem__(self, index): key = int(hashlib.md5((self.name + str(index)).encode('utf8')).hexdigest()[-16:], 16) rnd = random.Random(key) while True: idx = rnd.randrange(self.content_len - self.ctxlen + 256) self.fin.seek(idx) buf = self.fin.read(self.ctxlen + 256) j = 1 while j < 256 and buf[j - 1] != (b'\n')[0] and buf[j-1] != (b'\xff')[0]: j += 1 if j >= 256: continue buf = buf[j:j + self.ctxlen - 1] if b'\xff' in buf[:-1]: continue break vec = np.frombuffer(b'\x00' + buf, dtype = 'uint8').astype('int64') vec = torch.from_numpy(vec) return {'input_ids' : vec, 'labels' : vec}if __name__ == '__main__': import argparse parser = argparse.ArgumentParser( prog = 'train', description = 'train on dataset', ) parser.add_argument('--n_embd', dest = 'n_embd', type = int, default = 768) parser.add_argument('--n_layer', dest = 'n_layer', type = int, default = 12) parser.add_argument('--n_head', dest = 'n_head', type = int, default = 12) parser.add_argument('--bsize', dest = 'bsize', type = int, default = 4) parser.add_argument('--ctxlen', dest = 'ctxlen', type = int, default = 512) parser.add_argument('--accstep', dest = 'accstep', type = int, default = 1) parser.add_argument('--compute', dest = 'compute', type = int, default = 10000) args = parser.parse_args() cfg = transformers.GPT2Config(vocab_size = 256, n_positions = args.ctxlen, n_ctx = args.ctxlen, n_embd = args.n_embd, n_layer = args.n_layer, n_head = args.n_head) model = transformers.GPT2LMHeadModel(cfg) nparam = sum([np.prod(i.shape) for i in model.parameters()]) print('nparam', nparam) model = model.cuda() name = 'code_%dM_%dT_c%d_l%d_e%d_h%d_b%d_a%d'%(nparam // 1000000, args.compute, args.ctxlen, args.n_layer, args.n_embd, args.n_head, args.bsize, args.accstep) print(name) total_train_ops = 10 ** 12 * args.compute total_train_steps = total_train_ops // (nparam * args.bsize * args.accstep * args.ctxlen) print('steps', total_train_steps) train_ds = CodeDataset('train', args.ctxlen, 1048576) eval_ds = CodeDataset('valid', args.ctxlen, 1024) train_cfg = transformers.TrainingArguments( output_dir = name, num_train_epochs = total_train_steps * args.bsize * args.accstep / len(train_ds), save_steps = total_train_steps // 10, gradient_accumulation_steps = args.accstep, per_device_train_batch_size = args.bsize, per_device_eval_batch_size = args.bsize, evaluation_strategy = 'steps', eval_steps = (total_train_steps // 10), logging_strategy = 'steps', logging_steps = (total_train_steps // 100), report_to = 'none', ) trainer = transformers.Trainer( model = model, args = train_cfg, train_dataset = train_ds, eval_dataset = eval_ds, ) trainer.train()
演示
demo.py
解釋import transformersimport hashlibimport randomimport torchimport numpy as npdef sugguest_from_model(model, prefix): with torch.no_grad(): predicts = '' past_key_values = None cum_prob = 1 best_len = 0 best_sugguest = '' while len(predicts) < 32: if len(predicts) == 0: last_choice = np.int64([0] + [ord(i) for i in prefix])[None, :] ret = model.forward( input_ids = torch.from_numpy(last_choice), past_key_values = past_key_values, use_cache = True ) past_key_values = ret.past_key_values logits = ret.logits.cpu().numpy()[0, -1] prob = logits - logits.max(axis=0, keepdims = True) prob = np.exp(prob) prob = prob / np.sum(prob, axis = 0, keepdims = True) cur_predict = prob.argmax() if cur_predict == 10: break last_choice = np.int64([[cur_predict]]) cur_prob = prob[cur_predict] predicts = predicts + chr(cur_predict) cum_prob *= cur_prob if len(predicts) * cum_prob > best_len: best_len = len(predicts) * cum_prob best_sugguest = predicts if min(len(predicts) + 8, 32) * cum_prob < best_len: break return best_sugguestif __name__ == '__main__': import getch import termios ifname = './code_85M_40000T_c512_l12_e768_p2_b4_a1/checkpoint-228040/' model = transformers.GPT2LMHeadModel.from_pretrained(ifname) def demo_type(): print('loaded') attr = termios.tcgetattr(0) old_attr = attr[:] attr[3] = attr[3] & ~ (termios.ECHO | termios.ICANON) termios.tcsetattr(0, termios.TCSANOW, attr) lines = [] cur_line = '' sugguest = '' try: while True: c = getch.getch() if len(sugguest) != 0: print(' '*len(sugguest) + '\b' * len(sugguest), end = '', flush = True) if ord(c) == 27: break if ord(c) == 127: if len(cur_line) > 0: if cur_line[-1] == '\t': print('\b'*8, end = '', flush = True) else: print('\b \b', end = '', flush = True) cur_line = cur_line[:-1] elif c == '\n': print() lines.append(cur_line) cur_line = '' elif c == '\t' and (cur_line.replace('\t','') != ''): cur_line = cur_line + sugguest print(sugguest, end = '', flush = True) else: print(c, end = '', flush = True) cur_line += c def get_suggest(): ctx = cur_line i = len(lines) - 1 while i >= 0 and len(ctx) + len(lines[i]) + 1 < 200: ctx = lines[i] + '\n' + ctx i -= 1 return sugguest_from_model(model, ctx) sugguest = get_suggest() if cur_line.replace('\t','') != '': pass elif cur_line == '' and ord(c) != 127 and len(lines) > 0 and lines[-1].startswith('\t'): tabs = '' for j in range(len(sugguest)): if sugguest[j] == '\t': tabs = tabs + sugguest[j] else: break cur_line = tabs print(tabs, end = '', flush = True) if len(tabs): sugguest = get_suggest() else: sugguest = '' else: sugguest = '' print('\033[0;38;5;248m\033[0;48;5;223m' + sugguest + '\b'*len(sugguest) + '\033[0;0m', end = '', flush = True) except KeyboardInterrupt: pass termios.tcsetattr(0, termios.TCSANOW, old_attr) if len(cur_line): lines.append(cur_line) print() for line in lines: print(line) demo_type()