首页 >>  正文

代码教程表

来源:baiyundou.net   日期:2024-09-21

Pine 发自 凹非寺

量子位 | 公众号 QbitAI

现在只用60行代码,就能从0构建GPT了!

想当初,前特斯拉前AI总监的minGPT和nanoGPT也都还要300行代码。

这个60行代码的GPT也有名字,博主将它命名为PicoGPT

不过和此前minGPT和nanoGPT的教程不同,今天要讲的这个博主的教程,更侧重于代码实现部分,模型的权重则用已经训练好的。

对此,博主解释称这篇教程的重点在于提供一个简单且易于破解的完整技术介绍

这对还不理解GPT背后概念的盆友,算是非常友好了。

还有网友称赞,这篇博客介绍得非常清晰,第一部分尤为如此。

这篇介绍GPT模型的文章太好了,它比我之前看到的介绍都要清晰,至少在第一部分讨论文本生成和取样是这样的。

目前,此项目在GitHub上标星已破百,HackerNews上的点击量也即将破千。

从GPT是什么讲起

在介绍之前,还是需要说明一下,这篇教程不是完全零门槛,需要读者提前熟悉Python、NumPy以及一些基本的训练神经网络。

教程的重点聚焦在技术介绍上,统共有六大部分:

什么是GPT?

按照惯例,在正式构建GPT之前得先对它做一些基本介绍,教程从输入/输出、生成文本以及训练三个部分分别来讲GPT是如何工作的。

在这趴,博主附上代码,甚至还用了一些比喻来让读者们更好地理解GPT。

举个栗子

,在输入这一部分,作者将句子比作一条绳子,tokenizer则会将其分割成一小段一小段(单词),被称作token。

又比如说,在生成文本这part介绍自动回归时,博主直接贴上代码:

def generate(inputs, n_tokens_to_generate):

for _ in range(n_tokens_to_generate): # auto-regressive decode loop

output = gpt(inputs) # model forward pass

next_id = np.argmax(output[-1]) # greedy sampling

inputs = np.append(out, [next_id]) # append prediction to input

return list(inputs[len(inputs) - n_tokens_to_generate :]) # only return generated ids

input_ids = [1, 0] # "not" "all"

output_ids = generate(input_ids, 3) # output_ids = [2, 4, 6]

output_tokens = [vocab[i] for i in output_ids] # "heroes" "wear" "capes"

在每次迭代中,它会将预测的token追加回输入,这个预测未来值并将其添加回输入的过程就是GPT被描述为自动回归的原因。

60行代码怎么运行?

了解完GPT的基本概念之后,就直接快进到了如何在电脑上运行这个PicoGPT。

博主先是甩出了他那只有60行的代码:

import numpy as np

def gpt2(inputs, wte, wpe, blocks, ln_f, n_head):

pass # TODO: implement this

def generate(inputs, params, n_head, n_tokens_to_generate):

from tqdm import tqdm

for _ in tqdm(range(n_tokens_to_generate), "generating"): # auto-regressive decode loop

logits = gpt2(inputs, **params, n_head=n_head) # model forward pass

next_id = np.argmax(logits[-1]) # greedy sampling

inputs = np.append(inputs, [next_id]) # append prediction to input

return list(inputs[len(inputs) - n_tokens_to_generate :]) # only return generated ids

def main(prompt: str, n_tokens_to_generate: int = 40, model_size: str = "124M", models_dir: str = "models"):

from utils import load_encoder_hparams_and_params

# load encoder, hparams, and params from the released open-ai gpt-2 files

encoder, hparams, params = load_encoder_hparams_and_params(model_size, models_dir)

# encode the input string using the BPE tokenizer

input_ids = encoder.encode(prompt)

# make sure we are not surpassing the max sequence length of our model

assert len(input_ids) + n_tokens_to_generate < hparams["n_ctx"]

# generate output ids

output_ids = generate(input_ids, params, hparams["n_head"], n_tokens_to_generate)

# decode the ids back into a string

output_text = encoder.decode(output_ids)

return output_text

if name == "__main__":

import fire

fire.Fire(main)

然后从克隆存储库,安装依赖项等步骤一步步教你如何在电脑上运行GPT。

其中,还不乏一些贴心的小tips,比如说如果使用的是M1 Macbook,那在运行pip install之前,需要将requments.txt中的tensorflow更改为tensorflow-macos。

此外,对于代码的四个部分:gpt2,generate,main以及fire.Fire(main),博主也有做详细解释。

等到代码能够运行之后,下一步博主就准备详细介绍编码器、超参数(hparams)以及参数(params)这三部分了。

直接在笔记本或者Python会话中运行下面这个代码:

from utils import load_encoder_hparams_and_params

encoder, hparams, params = load_encoder_hparams_and_params("124M", "models")

Bingo!一些必要的模型和tokenizer文件就直接下载到model/124M,编码器、hparams和params也能直接加载。

更具体的内容这里就不多说了,教程的链接已经附在文末。

一些基础神经网络层的介绍

这一趴涉及到的知识就更加基础了,因为下一趴是实际GPT自身的架构,所以在此之前,需要了解一些非特定于GPT的更基本的神经网络层

博主介绍了GeLU、Softmax函数以及Layer Normalization和Linear。

GPT架构

终于!这部分要来讲GPT自身的架构了,博主从transformer的架构引入。

△transformer架构

GPT的架构只使用了transformer中的解码器堆栈(即图表的右边部分),并且其中的的“交叉注意”层也没有用到。

△GPT架构

随后,博主将GPT的架构总结成了三大部分:

  • 文本 + 位置嵌入
  • 变压器解码器堆栈
  • 下一个token预测头

并且还将这三部分用代码展示了出来,是酱紫的:

def gpt2(inputs, wte, wpe, blocks, ln_f, n_head): # [n_seq] -> [n_seq, n_vocab]

# token + positional embeddings

x = wte[inputs] + wpe[range(len(inputs))] # [n_seq] -> [n_seq, n_embd]

# forward pass through n_layer transformer blocks

for block in blocks:

x = transformer_block(x, block, n_head=n_head) # [n_seq, n_embd] -> [n_seq, n_embd]

# projection to vocab

x = layer_norm(x, ln_f) # [n_seq, n_embd] -> [n_seq, n_embd]

return x @ wte.T # [n_seq, n_embd] -> [n_seq, n_vocab]

再后面,就是关于这三部分的更多细节……

测试构建的GPT

这部分将全部的代码组合在一起,就得到了gpt2.py,统共有120行代码,删除注释和空格的话,就是60行。

然后测试一下!

python gpt2.py \\

"Alan Turing theorized that computers would one day become" \\

--n_tokens_to_generate 8

结果是这样的:

the most powerful machines on the planet.

成功了!

一些后续补充

最后一部分,博主也总结了这短短60行代码的不足:非常低效!

不过他还是给出了两个可以让GPT变高效的方法:

  • 同时地而不是顺序地执行注意力计算。
  • 实现 KV 缓存。

此外,博主还推荐了一些训练模型、评估模型以及改进架构的方法和教程。

感兴趣的话,直接戳文末链接~

作者介绍

Jay Mody,目前在加拿大一家NLP初创公司Cohere从事机器学习的工作,此前,他还分别在特斯拉和亚马逊作为软件工程师实习过一段时间。

除了这篇教程之外,小哥的博客网站上还有更新其他文章,并且都有附代码~

代码传送门:

https://github.com/jaymody/picoGPT/blob/29e78cc52b58ed2c1c483ffea2eb46ff6bdec785/gpt2_pico.py#L3-L58

教程链接:

https://jaykmody.com/blog/gpt-from-scratch/#putting-it-all-together

— 完 —

量子位 QbitAI · 头条号签约

","force_purephv":"0","gnid":"9fc3d7b5cdf28f7f1","img_data":[{"flag":2,"img":[{"desc":"","height":"248","title":"","url":"https://p0.ssl.img.360kuai.com/t01dfa1725bf21f2381.jpg","width":"1080"},{"desc":"","height":"266","title":"","url":"https://p0.ssl.img.360kuai.com/t01abbbe86e793fb603.jpg","width":"758"},{"desc":"","height":"196","title":"","url":"https://p0.ssl.img.360kuai.com/t017c4cef6055a6f906.jpg","width":"768"},{"desc":"","height":"104","title":"","url":"https://p0.ssl.img.360kuai.com/t0108ba73238b4f1cd0.jpg","width":"1066"},{"desc":"","height":"696","title":"","url":"https://p0.ssl.img.360kuai.com/t0167a75c2e2fd0607a.jpg","width":"968"},{"desc":"","height":"72","title":"","url":"https://p0.ssl.img.360kuai.com/t01532333bcb0730fed.jpg","width":"72"},{"desc":"","height":"308","title":"","url":"https://p0.ssl.img.360kuai.com/t018d4412530976d03f.jpg","width":"888"},{"desc":"","height":"400","title":"","url":"https://p0.ssl.img.360kuai.com/t01e291e1f5a317d009.jpg","width":"732"},{"desc":"","height":"1522","title":"","url":"https://p0.ssl.img.360kuai.com/t01904970ee221d4e8c.jpg","width":"1080"},{"desc":"","height":"772","title":"","url":"https://p0.ssl.img.360kuai.com/t01aa3ec31575ccfb86.jpg","width":"348"},{"desc":"","height":"800","title":"","url":"https://p0.ssl.img.360kuai.com/t017d2d2f07cb8e2101.jpg","width":"800"},{"desc":"","height":"784","title":"","url":"https://p0.ssl.img.360kuai.com/t01130eb6abd53f8742.jpg","width":"1080"}]}],"original":0,"pat":"qgc,art_src_3,fts0,sts0","powerby":"hbase","pub_time":1676780100000,"pure":"","rawurl":"http://zm.news.so.com/44381eda78aba92b92e6f3bba0dc9a8d","redirect":0,"rptid":"56704778322b5274","rss_ext":[],"s":"t","src":"量子位","tag":[{"clk":"ktechnology_1:mac","k":"mac","u":""}],"title":"60行代码就能构建GPT!网友:比之前的教程都要清晰|附代码

幸刚涛807Excel VBA代码
臧贪泉15335674347 ______ 提示: 插入模块,编辑一个宏: 1、用 GetObject 函数,将 D:\My Documents\数据备份.xls 文件读取到内存. 2、用语句 Range("A1000").End(xlUp).Row 读取最后一行行号n,假若超过1000行,只要改变1000即可. 3、判断A5单元格是否为...

幸刚涛807简述手工编程步骤 -
臧贪泉15335674347 ______ 1.分析零件图样和工艺要求 分析零件图样和工艺要求的目的,是为了确定加工方法、制定加工计划,以及确认与生产组织有关的问题,此步骤的内容包括: 确定该零件应安排在哪类或哪台机床上进行加工. 采用何种装夹具或何种装卡位方法...

幸刚涛807asp+sql统计表的代码
臧贪泉15335674347 ______ create procedure proc_newaccount @name varchar(10), @pid varchar(20), @telephone varchar(20), @openmoney money, @savingtype varchar(10), @address varchar (50)='' --默认 as declare @error int set @error=0 declare @cardid varchar(...

幸刚涛807SQL代码建表. -
臧贪泉15335674347 ______ 1.create database 学生选课数据库 2.create table 课程表 (课程号 char(6) primary key , 课程名 char(16) not null, 学分 number not null, 先行课程号 char(6)); 3.create table 学生表 (学号 char(6) primary key, 姓名 char(16) not null, 性...

幸刚涛807帮我写一个EXCEL的代码,简单的 -
臧贪泉15335674347 ______ 代码如下: Private Sub Worksheet_Change(ByVal Target As Range) If Target.Row = 6 And Target.Column = 10 Then Cells(8, 8).Select End If End Sub代码不能放在模块中,需要放在工作表中.

幸刚涛807Html代码【table】
臧贪泉15335674347 ______ <table border="0" cellspacing="0" cellpadding="0"> <tr> <td><img name="" src="" width="430" height="430" alt="" /></td> <td><div style="height:430px; overflow:auto; overflow-x:hidden;">描述</div></td> </tr></table> 你自己再定义一下table的宽度就行了...

幸刚涛807excel vba代码解释
臧贪泉15335674347 ______ Range("B65536").End(xlUp).Row '返回B列的最后一个有数据单元格所在的行号,例如100 Range("B2:B" &amp; .Range("B65536").End(xlUp).Row) '选中从B2开始到最后一个有数据单元格的区域范围,例如("B2:B100") SpecialCells(xlCellTypeBlanks) '选中某区域中的空白单元格 整段代码的意思是:将"观察表"中B2单元格的数据粘贴到"二维表"B列中的所有空单元格中(作用范围从B2开始到B列最后一个有数据的单元格).

幸刚涛807C语言——线性表 -
臧贪泉15335674347 ______ #include"stdio.h" #include<malloc.h> typedef char ElemType; typedef struct LNode {ElemType data; struct LNode *next; }LinkList; void CreatListF(LinkList *&L,ElemType a[],int n) //头插法建表 { LinkList *s;int i; L=(LinkList *)malloc(sizeof(...

幸刚涛807编写一个SQLSERVER 存储过程 -
臧贪泉15335674347 ______ 代码是最好的文字,不多说,请看我的代码,并给分,呵呵.--step1. 建表if exists(select * from sysobjects where id=object_id('student') and objectproperty(id,'IsTable...

幸刚涛807excel 宏 代码 -
臧贪泉15335674347 ______ 5分少了,加分吧! 身边没有可用的office程序,试着写下吧: sub macro1() dim s1ce as string '记录表1中的CE列值 dim s2ce as string '记录表2中的CE列值 dim s2r1 as string '记录表2中的顺序行值 dim s2r2 as string '记录表2中的查找行值 ...

(编辑:自媒体)
关于我们 | 客户服务 | 服务条款 | 联系我们 | 免责声明 | 网站地图 @ 白云都 2024