基于预训练模型 ERNIE-Gram 实现语义匹配-模型搭建

本章继续分析该案例的模型搭建部分
自从 2018 年 10 月以来,NLP 个领域的任务都通过 Pretrain + Finetune 的模式相比传统 DNN 方法在效果上取得了显著的提升,本节我们以百度开源的预训练模型 ERNIE-Gram 为基础模型,在此之上构建 Point-wise 语义匹配网络。

前言:

首先,先了解一下Pretrain+Finetune是什么
预训练(Pretrain):当我们想要搭建一个网络模型来完成一个任务时,需要初始化参数,然后开始训练网络,不断减小损失,最后得到优秀的参数,把这些参数保存起来,以便于下次直接使用。
微调 (Finetune):使用别人的参数,为了适应自己的数据集,需要对这些参数进行微调。

具体步骤:

基于 ERNIE-Gram 模型结构搭建神经网络,首先先把 ERNIE-Gram 模型导入进来

import paddle.nn as nn
pretrained_model = paddlenlp.transformers.ErnieGramModel.from_pretrained('ernie-gram-zh')

看一下pretrained_model具体的内容

可以初步看到,这个模型有12层网络以及embeddings的维度,embeddings里含有权重参数,来具体查看一下word_embeddings的weight

接下来,我们定义自己的神经网络

class PointwiseMatching(nn.Layer):

    # 此处的 pretained_model 在本例中会被 ERNIE-Gram 预训练模型初始化
    def __init__(self, pretrained_model, dropout=None):
        super().__init__()
        self.ptm = pretrained_model
        self.dropout = nn.Dropout(dropout if dropout is not None else 0.1)

        # 语义匹配任务: 相似、不相似 2 分类任务
        self.classifier = nn.Linear(self.ptm.config["hidden_size"], 2)

    def forward(self,
                input_ids,
                token_type_ids=None,
                position_ids=None,
                attention_mask=None):
        # 此处的 Input_ids 由两条文本的 token ids 拼接而成
        # token_type_ids 表示两段文本的类型编码
        # 返回的 cls_embedding 就表示这两段文本经过模型的计算之后而得到的语义表示向量
        _, cls_embedding = self.ptm(input_ids, token_type_ids, position_ids,
                                    attention_mask)


        cls_embedding = self.dropout(cls_embedding)

        # 基于文本对的语义表示向量进行 2 分类任务
        logits = self.classifier(cls_embedding)
        probs = F.softmax(logits)

        return probs

自己定义的函数表面上看上去有四层:pretrained_model层、Dropout层(通过防止过拟合来减少过拟合的正则化技术)、Linear层(进行 2 分类任务)、softmax(激活函数)
而实际上,pretrained_model模型里还含了12层,我们也是基于这个模型进行第一步计算,而计算的结果我找了一下源码发现返回两个值:sequence_output(模型最后一层的隐藏状态序列)和 pooled_output(序列中第一个标记(’ [CLS] ‘)的输出)

该任务只需要pooled_output,用cls_embedding把它保存,因为之前定义了batch_size=32,而hidden_size的值通过查看源码是768,因此cls_embedding的维度是[32,768]
然后我们对cls_embedding防止过拟合以及二分类,最后softmax一下,前向计算过程结束,返回结果probs

# 定义 Point-wise 语义匹配网络
model = PointwiseMatching(pretrained_model)

模型搭建到此结束,接下来就是训练了。

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
乘风的头像乘风管理团队
上一篇 2022年5月18日
下一篇 2022年5月18日

相关推荐