当前位置: 华文世界 > 科技

有手就能训——给AI「老人」写的transformer序列建模指南

2024-03-29科技

作者丨Haoqiang Fan@知乎(已授权)

来源丨https://zhuanlan.zhihu.com/p/686616390

编辑丨极市平台

相信很多人和我一样都是从「古典深度学习」时代一路走来的,面对当今「LLM才是AI「的时代,有着很多的不适应症状。看着那么多的新论文里的算法,想从头开始撸一遍发现要学习一大堆有的没的的东西,然后 paper 里提到的实验条件还吓人的高,似乎这个喧嚣的新世界和自己有点遥不可及。

那么,有没有什么办法能在「一无所有」的状态下做点啥呢?

首先,进入新时代了,要更新好自己的思想,放心大胆的当「调包侠」。而现在的确有很多很科学的包!

其中以 hugging face 的 transformers 为集大成者。实际上,从自娱自乐的 toy example,到一系列还挺有影响力的项目比如 Vicuna,LLAVA,翻开代码库,都能看到那行金光闪闪的

import transformers

甚至,翻遍整个repo,都找不到「网络结构」写在哪里,只有一行

from transformers import LlamaForCausalLM

在等着你。

而在 2024 年了,这些包的安装也没那么「阴间」,直接 pip install 回来的 transformers,accelerate,就是亲测能用的(当然,假设你已经把 pytorch、CUDA 的安装和修bug搞定了)。反正我在玩的时候真就是装了就能用,没啥玄学。

对于模型训练 ,官方文档并没有一个「最小集」的样本,不过对着文档琢磨一下还是很容易写出来的:

嗯,是的,一共就16行,配好两个对象,然后 Trainer.train() 就成了…… 我第一次用的时候没配 save_steps 导致跑完了不知道模型存哪了,查了一下文档才搞明白。

其中,dataset对象 就是 torch.utils.data.Dataset,要实现 __ len __ 和 __ getitem __ 的接口,这个自己搞搞就好了。

而要使用一个训练好的模型,直接

import transformersmodel = transformers.GPT2LMHeadModel.from_pretrained(ifname)

都不用自己去手工维护模型的超参数列表,这个库在 checkpoint 文件夹里已经自己按约定存好了。

然后肯定有人会提出质疑了,这么直接搞真的严肃么?
不过,让我们回想一下为啥 CV 类的模型的训练代码都那么「复杂」,然后就能发现现在的这种搞法的「科学」之处了:

  • 数据集 。CV 类模型往往要使用大量的压缩(jpg/p64)后的数据并进行大量的增广来进行训练,训练过程中总的 IO 吞吐是恐怖的。于是乎几乎人人都要会写分布式 data provider 框架,甚至还要自己搞专门的数据提供的加速办法(例如,多个实验共享data provider)。然而,在 NLP 领域,就算是大如 GPT level 的模型,也就几 T tokens 就能训明白了,弄个 SSD 本地存一下挺香的。。。最多是 tokenizer 之类的预处理部分折腾一下。因为每 token 的训练 flo 的巨大区别,在这个地方花费的精力是可以极大的节省的。
  • 调参 。CNN 类的模型,好多地方调一调参还是挺有用的。而来到了 2024 年的 transformers,模型一方面很「简单」,大量的 transformer block 的堆叠,想乱改涨点还不好涨;另一方面又很「复杂 」,面对着跑一次要花很多 $ 的基础模型调参实验(比如你想把gelu换回relu?),穷人也就不用去想着折腾了。开源世界调出来的那一组「默认参数」,其实还挺 work 的,于是乎就成了一个「非必要别乱调」的格局。
  • 表示 。搞 Vision 的模型「如何表示输入输出」是一门大学问,改一点点影响巨大。而搞序列建模的时候……输入输出就是 tokens 的序列,torch.LongTensor,没了。。管你的 tokens 里是画了一棵树,建了一张图,还是存了一大段 json ,在 NN 这边看来,都(至少在观念上)是一样一样的。
  • 部署 。「古典」AI模型的一大主题是如何在端上跑,所以就产生了一大堆的模型打包、静态化、计算图转化等等的 concern。不过在「新时代」,还远没有卷到这个地步,能 python 跑就还算可以了,这也就减少了大量为「训推一体」付出的精力。
  • transformers 里的 Trainer 集成了一些很科学的默认行为,例如定期存 checkpoint 放到实验名的文件夹下面,把各种曲线信息同步到 tensorboard、clearml、wandb 等监控软件,啥参数都不传也是可以接受的选择。

    而如果你就是要调参,在 TrainingArguments、Trainer、GPT2Config 等地方一共有 138 个可以传的参数,以及大量通过调方法来填的参数,能满足不少的需要。比如,可以传一个 fp16 = True 来「一键」加速训练,而它背后是 apex.amp 这种库在支撑。

    所以来说,把「靠谱」的库「集成」在一起,暴露出一组带有合理默认参数的接口,这个方法论在 2024 年来看的确还是可行的。当层层调用的这些库不「炸」的时候,使用体验相当的丝滑。

    然而,如果「炸」了,或者想搞一些比较深度的 hack (比如定制某个算子在训练的时候触发一个神奇的行为),就会发现,在一行看似无辜的网络定义之上,还摞着层层叠叠的 加速库、优化库、分布式库、混合精度库等等,想去「一层层debug」就成了一件恐怖的事情。

    好不容易(?)学会了序列建模这个技能(又称:transformers 库的安装与使用),不整点啥活似乎没意思。

    于是,我用我的笔记本电脑(带3060显卡)训练了一个 GPT2-small 级别的模型来做 C++ 代码的补全。

    我从 github 上找了一些经典的 c/c++ 的 repo(如,linux kernel,gcc,cpython,等),收集了 1G 多的 .zip,从里面找出所有的 c类源文件,形成了一个 2.2 GB 的训练集。

    不做tokenization,23万步,长度 512,批大小 4,训练大概需要不到一天。
    然后搞了一个贪心的补全推荐的方法,试了试,好像也不是完全不能用:

    视频加载中...

    训练出来的语言模型压缩率大概是 0.7 bits / byte,比 xz -9 还是要强一点的。

    考虑到实现出这样的效果,从搜索库名字到实现出来只需要一天,这还是挺惊人的了。
    当然,一旦想从这个 baseline 开始提升效果,那么 language model 里各种考量的的「洪流」就要来了:

  • 用 tokenizer 进行文本的预压缩,提升 token 的效率(window size是一种昂贵的资源),同时让网络不用去学习各种「显然」的东西。各种 tokenizer 以及训练加速的方法就可以选一选了
  • 清洗数据,减少类似训练数据被 linux kernel 里的 #define 淹没带来的副作用,提升「高质量」数据的占比
  • 更大的模型!找更多更好的卡!另外还要搞预训练
  • 最优的批大小、并行数、learning rate调整,对着 loss 曲线「围炉炼丹」
  • 把 suffix 预测改为 infix 预测,加入文件名、函数名等的各种 hint
  • 考虑用户反馈数据的利用,考虑加入用户问答,考虑针对用户的 finetune
  • 等等等等……
  • 这就是一个很深很深的坑了,而 copilot 之类就是「登峰造极」之后的产物。想往这方面卷,就会一步步走入「去哪融资,去哪买卡,国产半导体何时崛起」的无尽焦虑中。

    不过好的一面在于,如果我们不把自己当做一个 「训 LLM」 的人,而是当一个 「用 transfomers 库进行序列建模」 的,就会发现事情也没那么坏。虽然几十M的「小」(对 vision 来说其实不小了)模型并不会表现出「大」模型的一些独特能力,但是也已经可以做一些很有意义的事情了(我还真没试过用 CNN 硬怼上面的这个 demo 是什么效果……)。

    刨开训练 transformer 类模型的一些新的技巧,当前这个时代其实最关键的问题就变成了: 手头有啥好的问题,可以表达成一个序列建模吗?

    如果有,或者原来的某些「老」问题可以这么表示,那不妨用 transformer 们来试一试,也许就有全新的可能。

    新的时代,就应该去拥抱新的方法。

    最后,附代码补全demo的全套代码:训练

    train.py

    解释import transformersimport numpy as npimport torchimport osimport randomimport hashlib class CodeDataset(torch.utils.data.Dataset): def __init__(self, name, ctxlen, totalnum): super().__init__() self.name = name fname = './data/codes.txt' self.content_len = os.path.getsize(fname) self.fin = open(fname, 'rb') self.ctxlen = ctxlen self.length = totalnum def __len__(self): return self.length def __getitem__(self, index): key = int(hashlib.md5((self.name + str(index)).encode('utf8')).hexdigest()[-16:], 16) rnd = random.Random(key) while True: idx = rnd.randrange(self.content_len - self.ctxlen + 256) self.fin.seek(idx) buf = self.fin.read(self.ctxlen + 256) j = 1 while j < 256 and buf[j - 1] != (b'\n')[0] and buf[j-1] != (b'\xff')[0]: j += 1 if j >= 256: continue buf = buf[j:j + self.ctxlen - 1] if b'\xff' in buf[:-1]: continue break vec = np.frombuffer(b'\x00' + buf, dtype = 'uint8').astype('int64') vec = torch.from_numpy(vec) return {'input_ids' : vec, 'labels' : vec}if __name__ == '__main__': import argparse parser = argparse.ArgumentParser( prog = 'train', description = 'train on dataset', ) parser.add_argument('--n_embd', dest = 'n_embd', type = int, default = 768) parser.add_argument('--n_layer', dest = 'n_layer', type = int, default = 12) parser.add_argument('--n_head', dest = 'n_head', type = int, default = 12) parser.add_argument('--bsize', dest = 'bsize', type = int, default = 4) parser.add_argument('--ctxlen', dest = 'ctxlen', type = int, default = 512) parser.add_argument('--accstep', dest = 'accstep', type = int, default = 1) parser.add_argument('--compute', dest = 'compute', type = int, default = 10000) args = parser.parse_args() cfg = transformers.GPT2Config(vocab_size = 256, n_positions = args.ctxlen, n_ctx = args.ctxlen, n_embd = args.n_embd, n_layer = args.n_layer, n_head = args.n_head) model = transformers.GPT2LMHeadModel(cfg) nparam = sum([np.prod(i.shape) for i in model.parameters()]) print('nparam', nparam) model = model.cuda() name = 'code_%dM_%dT_c%d_l%d_e%d_h%d_b%d_a%d'%(nparam // 1000000, args.compute, args.ctxlen, args.n_layer, args.n_embd, args.n_head, args.bsize, args.accstep) print(name) total_train_ops = 10 ** 12 * args.compute total_train_steps = total_train_ops // (nparam * args.bsize * args.accstep * args.ctxlen) print('steps', total_train_steps) train_ds = CodeDataset('train', args.ctxlen, 1048576) eval_ds = CodeDataset('valid', args.ctxlen, 1024) train_cfg = transformers.TrainingArguments( output_dir = name, num_train_epochs = total_train_steps * args.bsize * args.accstep / len(train_ds), save_steps = total_train_steps // 10, gradient_accumulation_steps = args.accstep, per_device_train_batch_size = args.bsize, per_device_eval_batch_size = args.bsize, evaluation_strategy = 'steps', eval_steps = (total_train_steps // 10), logging_strategy = 'steps', logging_steps = (total_train_steps // 100), report_to = 'none', ) trainer = transformers.Trainer( model = model, args = train_cfg, train_dataset = train_ds, eval_dataset = eval_ds, ) trainer.train()

    演示
    demo.py

    解释import transformersimport hashlibimport randomimport torchimport numpy as npdef sugguest_from_model(model, prefix): with torch.no_grad(): predicts = '' past_key_values = None cum_prob = 1 best_len = 0 best_sugguest = '' while len(predicts) < 32: if len(predicts) == 0: last_choice = np.int64([0] + [ord(i) for i in prefix])[None, :] ret = model.forward( input_ids = torch.from_numpy(last_choice), past_key_values = past_key_values, use_cache = True ) past_key_values = ret.past_key_values logits = ret.logits.cpu().numpy()[0, -1] prob = logits - logits.max(axis=0, keepdims = True) prob = np.exp(prob) prob = prob / np.sum(prob, axis = 0, keepdims = True) cur_predict = prob.argmax() if cur_predict == 10: break last_choice = np.int64([[cur_predict]]) cur_prob = prob[cur_predict] predicts = predicts + chr(cur_predict) cum_prob *= cur_prob if len(predicts) * cum_prob > best_len: best_len = len(predicts) * cum_prob best_sugguest = predicts if min(len(predicts) + 8, 32) * cum_prob < best_len: break return best_sugguestif __name__ == '__main__': import getch import termios ifname = './code_85M_40000T_c512_l12_e768_p2_b4_a1/checkpoint-228040/' model = transformers.GPT2LMHeadModel.from_pretrained(ifname) def demo_type(): print('loaded') attr = termios.tcgetattr(0) old_attr = attr[:] attr[3] = attr[3] & ~ (termios.ECHO | termios.ICANON) termios.tcsetattr(0, termios.TCSANOW, attr) lines = [] cur_line = '' sugguest = '' try: while True: c = getch.getch() if len(sugguest) != 0: print(' '*len(sugguest) + '\b' * len(sugguest), end = '', flush = True) if ord(c) == 27: break if ord(c) == 127: if len(cur_line) > 0: if cur_line[-1] == '\t': print('\b'*8, end = '', flush = True) else: print('\b \b', end = '', flush = True) cur_line = cur_line[:-1] elif c == '\n': print() lines.append(cur_line) cur_line = '' elif c == '\t' and (cur_line.replace('\t','') != ''): cur_line = cur_line + sugguest print(sugguest, end = '', flush = True) else: print(c, end = '', flush = True) cur_line += c def get_suggest(): ctx = cur_line i = len(lines) - 1 while i >= 0 and len(ctx) + len(lines[i]) + 1 < 200: ctx = lines[i] + '\n' + ctx i -= 1 return sugguest_from_model(model, ctx) sugguest = get_suggest() if cur_line.replace('\t','') != '': pass elif cur_line == '' and ord(c) != 127 and len(lines) > 0 and lines[-1].startswith('\t'): tabs = '' for j in range(len(sugguest)): if sugguest[j] == '\t': tabs = tabs + sugguest[j] else: break cur_line = tabs print(tabs, end = '', flush = True) if len(tabs): sugguest = get_suggest() else: sugguest = '' else: sugguest = '' print('\033[0;38;5;248m\033[0;48;5;223m' + sugguest + '\b'*len(sugguest) + '\033[0;0m', end = '', flush = True) except KeyboardInterrupt: pass termios.tcsetattr(0, termios.TCSANOW, old_attr) if len(cur_line): lines.append(cur_line) print() for line in lines: print(line) demo_type()