动手学深度学习——BERT代码

张开发
2026/6/14 11:19:00 15 分钟阅读
动手学深度学习——BERT代码
1. 前言上一篇我们已经从整体上理解了BERT它是基于 Transformer Encoder 的双向预训练语言模型它采用“预训练 微调”的范式它的核心预训练任务包括MLMMasked Language ModelingNSPNext Sentence Prediction它特别适合各种自然语言理解任务这一篇就继续按李沐的思路把 BERT 真正落到代码层面。这一节最重要的不是一下子把整个预训练过程全写完而是先看清楚 BERT 这个模型本身到底由哪些部分组成输入表示怎么构造BERT 编码器层怎么堆叠MLM 预测头怎么接NSP 分类头怎么接前向传播输出到底有哪些东西如果一句话概括这一节代码的核心那就是BERT 输入嵌入层 多层 Transformer 编码器 预训练任务头2. BERT 代码要解决什么问题如果从实现角度看BERT 这节代码主要解决三件事第一构造输入表示BERT 输入不是单纯 token embedding而是token embeddingsegment embeddingposition embedding三者相加。第二搭建深层 Transformer Encoder让输入序列经过多层自注意力和前馈网络得到上下文化表示。第三接上预训练任务输出头包括MLM 头预测被 mask 的 tokenNSP 头判断两句是否连续所以这一节并不是直接训练而是在把BERT 的骨架和输出接口先搭完整。3. BERT 的输入为什么比普通模型复杂前面我们学 RNN、GRU、LSTM、Seq2Seq 时输入往往主要是token 索引或者 token embedding但 BERT 不一样。因为它不仅要处理单句还常常要处理句对任务并且 Transformer 本身没有顺序递推结构所以还必须显式加入位置信息。因此BERT 输入通常由三部分组成3.1 Token Embedding表示词本身是谁。3.2 Segment Embedding表示这个 token 属于句子 A 还是句子 B。3.3 Position Embedding表示这个 token 在序列中的位置。所以 BERT 输入层本质上比前面的模型更“结构化”。4. BERT 输入嵌入层通常怎么写李沐这里常见的 BERT 编码器初始化大致会写成这样class BERTEncoder(nn.Module): def __init__(self, vocab_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout, max_len1000, key_size768, query_size768, value_size768, **kwargs): super(BERTEncoder, self).__init__(**kwargs) self.token_embedding nn.Embedding(vocab_size, num_hiddens) self.segment_embedding nn.Embedding(2, num_hiddens) self.blks nn.Sequential() for i in range(num_layers): self.blks.add_module( f{i}, d2l.EncoderBlock(key_size, query_size, value_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, dropout, True) ) self.pos_embedding nn.Parameter(torch.randn(1, max_len, num_hiddens))这段代码基本就把 BERT 编码器的主体结构全展示出来了。5. 为什么token_embedding很好理解这一句self.token_embedding nn.Embedding(vocab_size, num_hiddens)和前面很多 NLP 模型一样作用就是把 token 索引变成稠密向量。例如deeplearning[CLS][SEP]这些 token 本身是整数编号经过 embedding 以后变成num_hiddens维的向量表示。这是所有输入表示的基础。6. 为什么segment_embedding只需要 2这一句self.segment_embedding nn.Embedding(2, num_hiddens)中的2表示只区分两种 segment通常是句子 A句子 B例如在句对输入中[CLS] 句子A [SEP] 句子B [SEP]那么句子 A 对应 segment id 0句子 B 对应 segment id 1所以 segment embedding 只需要 2 种类型就够了。它的作用是告诉模型当前 token 属于哪一段句子这对句对任务很重要。7. 为什么 BERT 的位置编码这里是可学习参数这一句self.pos_embedding nn.Parameter(torch.randn(1, max_len, num_hiddens))表示BERT 这里使用的是可学习的位置编码也就是说位置向量不是固定公式生成的而是直接作为模型参数参与训练。这和最原始 Transformer 论文里常见的正弦位置编码不同。BERT 采用可学习位置嵌入也是很经典的做法。它的含义很简单序列第 1 个位置有一个向量第 2 个位置有另一个向量…这些位置向量会在训练中自动优化8. 为什么 BERT 需要多层EncoderBlock这一段self.blks nn.Sequential() for i in range(num_layers): self.blks.add_module(... d2l.EncoderBlock(...))说明 BERT 编码器并不是一层 Transformer而是多层 Transformer Encoder Block 堆叠这和前面讲“深层循环神经网络”很像只不过这里堆叠的不是 RNN / GRU / LSTM而是多头自注意力残差连接LayerNorm前馈网络共同组成的 Transformer 编码块。所以 BERT 强并不只是因为它有自注意力还因为它是多层深度 Transformer 编码器9. BERT 编码器前向传播怎么写常见写法如下def forward(self, tokens, segments, valid_lens): X self.token_embedding(tokens) self.segment_embedding(segments) X X self.pos_embedding.data[:, :X.shape[1], :] for blk in self.blks: X blk(X, valid_lens) return X这段代码非常关键因为它把 BERT 编码器的数据流真正串起来了。10. 为什么一开始要把三种 embedding 相加这一句的核心是X token_embedding segment_embedding position_embedding这意味着每个位置的最终输入表示不是单一来源而是三种信息的叠加。具体来说token embedding告诉模型这个词本身是谁。segment embedding告诉模型它属于句子 A 还是句子 B。position embedding告诉模型它在整个序列中的哪个位置。所以 BERT 输入层的目标就是把词信息 句段信息 位置信息统一融合成一个向量表示。11.for blk in self.blks: X blk(X, valid_lens)在干什么这表示把输入依次送过多层 Transformer Encoder Block。每一层都会做自注意力前馈网络残差连接LayerNorm经过多层之后X中每个位置的表示都会变成深层上下文化表示也就是说到了最后输出时[CLS]位置表示的是整段综合信息普通 token 位置表示的是结合上下文后的 token 表示这正是 BERT 的核心输出。12. BERT 编码器输出的X到底是什么最终返回的X形状通常是(batch_size, num_steps, num_hiddens)它表示整个输入序列中每个位置的上下文化表示注意这已经不是原始 embedding 了而是经过多层 Transformer 编码之后的结果。这个输出后面会被分别送给MLM 头NSP 头或者下游微调任务头所以 BERT 的主体输出本质上就是高质量上下文化 token 表示序列13. MLM 预测头为什么单独写一个模块因为 MLM 的任务不是预测所有位置而是只预测那些被选中的 mask 位置。所以通常会专门写一个MaskLM类例如class MaskLM(nn.Module): def __init__(self, vocab_size, num_hiddens, num_inputs768, **kwargs): super(MaskLM, self).__init__(**kwargs) self.mlp nn.Sequential( nn.Linear(num_inputs, num_hiddens), nn.ReLU(), nn.LayerNorm(num_hiddens), nn.Linear(num_hiddens, vocab_size) )这说明 MLM 不是直接拿编码器输出硬做分类而是会再经过一个小型预测头。14. 为什么 MLM 头里还要有 MLP因为 BERT 主体输出的是上下文化表示而 MLM 任务需要的是对词表中所有 token 的预测分布中间加一个小 MLP有几个好处第一增加表达能力让 MLM 预测头更灵活不只是一个线性映射。第二更贴近预训练目标让模型在输出到词表前再做一次非线性变换。第三和原始 BERT 结构一致原版 BERT 的 MLM head 本来也不是最简单单层线性。所以这里写成一个小的投影头是合理的。15. MLM 前向传播为什么只取被 mask 的位置常见写法大致如下def forward(self, X, pred_positions): num_pred_positions pred_positions.shape[1] pred_positions pred_positions.reshape(-1) batch_size X.shape[0] batch_idx torch.arange(0, batch_size) batch_idx torch.repeat_interleave(batch_idx, num_pred_positions) masked_X X[batch_idx, pred_positions] masked_X masked_X.reshape((batch_size, num_pred_positions, -1)) mlm_Y_hat self.mlp(masked_X) return mlm_Y_hat这里最关键的是只把被 mask 的那些位置挑出来做预测因为 MLM 的损失只计算这些位置。没被 mask 的位置不参与 MLM 分类目标。所以这段代码本质上是在做从编码器输出里按位置索引抽取 mask 位置表示再送入 MLM 预测头16. 为什么pred_positions很重要pred_positions记录的是本条样本里哪些位置被选作 MLM 预测目标例如一句话长度 10可能只 mask 了第 2、5、8 个位置。那么pred_positions [2, 5, 8]这时 MLM 头就只处理这 3 个位置对应的表示。所以 BERT 预训练不只是把[MASK]塞进去这么简单还必须显式记录哪些位置需要参与 MLM loss17. NSP 头为什么更简单和 MLM 不同NSP 是一个句级二分类任务判断句子 B 是否是句子 A 的真实后续所以它只需要拿整段输入的一个综合表示做分类即可。在 BERT 里最经典做法就是用[CLS]位置对应的输出表示。因此 NSP 头通常很简单例如class NextSentencePred(nn.Module): def __init__(self, num_inputs, **kwargs): super(NextSentencePred, self).__init__(**kwargs) self.output nn.Linear(num_inputs, 2) def forward(self, X): return self.output(X)这里输出维度为 2表示是下一句不是下一句18. 为什么 NSP 通常只看[CLS]表示因为[CLS]的设计目标本来就是汇总整段输入的全局信息经过多层 Transformer 编码后[CLS]位置的表示通常已经融合了整句甚至句对的整体语义。所以拿它做文本分类句对判断NSP都很自然。这也是为什么 BERT 对句级任务特别方便。19. 最终 BERT 模型怎么组装李沐这里常见的总模型类大致会写成这样class BERTModel(nn.Module): def __init__(self, vocab_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout, max_len1000, key_size768, query_size768, value_size768, hid_in_features768, mlm_in_features768, nsp_in_features768): super(BERTModel, self).__init__() self.encoder BERTEncoder( vocab_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout, max_lenmax_len, key_sizekey_size, query_sizequery_size, value_sizevalue_size) self.hidden nn.Sequential(nn.Linear(hid_in_features, num_hiddens), nn.Tanh()) self.mlm MaskLM(vocab_size, num_hiddens, mlm_in_features) self.nsp NextSentencePred(nsp_in_features)这就把BERT 编码器MLM 头NSP 头全都组装到一起了。20. 为什么这里还有一个self.hidden这一句self.hidden nn.Sequential(nn.Linear(hid_in_features, num_hiddens), nn.Tanh())通常是为了对[CLS]表示先做一次额外变换再送入 NSP 分类头。也就是说编码器输出[CLS]表示经过一个小投影层再做 NSP 分类这和原始 BERT 的设计是对应的。它让句级分类头不只是一个裸线性层而是有一个中间变换。21. BERT 总模型前向传播怎么写常见写法如下def forward(self, tokens, segments, valid_lensNone, pred_positionsNone): encoded_X self.encoder(tokens, segments, valid_lens) if pred_positions is not None: mlm_Y_hat self.mlm(encoded_X, pred_positions) else: mlm_Y_hat None nsp_Y_hat self.nsp(self.hidden(encoded_X[:, 0, :])) return encoded_X, mlm_Y_hat, nsp_Y_hat这段代码把 BERT 的整体输出逻辑一下串起来了。22. 为什么返回三个结果这一点很关键return encoded_X, mlm_Y_hat, nsp_Y_hat说明 BERT 总模型不是只输出一个东西。encoded_X表示整段输入所有位置的上下文化表示。后续微调任务常常直接用它。mlm_Y_hat表示 MLM 任务在 mask 位置上的预测结果。预训练时要算 MLM loss。nsp_Y_hat表示 NSP 二分类结果。预训练时要算 NSP loss。所以 BERT 模型主体并不是某个单一任务模型而是一个共享编码器 多任务预训练头23. 为什么encoded_X[:, 0, :]就是[CLS]表示因为[CLS]通常被放在输入序列最前面。所以编码器输出张量encoded_X.shape (batch_size, num_steps, num_hiddens)中encoded_X[:, 0, :]就表示每个样本第 0 个位置也就是[CLS]的输出表示。这正好能作为整句综合表示用于 NSP 或其他句级任务。24. 这一节代码最该掌握什么如果从学习重点来看最关键的是下面几件事。24.1 BERT 输入表示由哪三部分组成token embeddingsegment embeddingposition embedding24.2 BERT 主体为什么是多层 Transformer Encoder这是它上下文化表示能力的来源。24.3 MLM 头和 NSP 头分别做什么MLMtoken 级预测NSP句级二分类24.4 为什么 MLM 只取被 mask 的位置因为预训练目标只在这些位置上计算。24.5 为什么[CLS]位置输出可用于句级任务因为它是整段输入的综合表示。25. 这一节和后面几节怎么衔接这一节其实只是把BERT 模型结构搭起来。而你给的目录后面还会继续讲BERT 预训练数据代码BERT 预训练代码BERT 微调自然语言推理数据集BERT 微调代码所以这节可以理解成先搭模型本体后面几节才会逐步补齐数据怎么造损失怎么算训练怎么跑下游任务怎么接这个顺序非常合理。26. 本节总结这一节我们学习了 BERT 的代码结构核心内容可以总结为以下几点。26.1 BERT 编码器由三部分输入表示和多层 Transformer Encoder 组成这是模型主体。26.2 输入表示包括 token、segment 和 position 三类 embedding三者相加形成最终输入。26.3 MLM 头负责预测被 mask 的 token因此只取被 mask 的位置做分类。26.4 NSP 头负责句级二分类通常使用[CLS]位置的输出表示。26.5 整个 BERT 模型本质上是“共享编码器 多任务预训练头”这是它预训练范式的核心。27. 学习感悟这一节特别重要因为它让你第一次真正看到BERT 并不是一个“神秘黑箱”它其实就是Transformer 编码器 精心设计的输入表示 预训练任务头。很多时候大家一提 BERT 就觉得它很大、很复杂。但如果把结构拆开其实非常清楚先把输入表示准备好用深层自注意力编码再接两个预训练任务头真正难的地方不是它概念多玄而是它把这些部分组织得特别好。

更多文章