Path: blob/main/transformers_doc/zh/tensorflow/summarization.ipynb
8775 views
摘要
加载 BillSum 数据集
首先从 🤗 Datasets 库中加载 BillSum 数据集中较小的加利福尼亚州法案子集:
使用 train_test_split 方法将数据集划分为训练集和测试集:
然后查看一个示例:
您会用到的两个字段是:
text:法案文本,将作为模型的输入。summary:text的精简版本,将作为模型的目标输出。
预处理
下一步是加载 T5 分词器,处理 text 和 summary:
您要创建的预处理函数需要:
在输入前添加提示词,让 T5 知道这是一个摘要任务。某些能够处理多种 NLP 任务的模型需要针对特定任务提示。
在对标签进行分词时使用关键字参数
text_target。将序列截断至不超过
max_length参数设置的最大长度。
使用 🤗 Datasets 的 map 方法将预处理函数应用于整个数据集。通过设置 batched=True 一次处理数据集的多个元素,可以加速 map 函数:
现在使用 DataCollatorForSeq2Seq 创建一批样本。在整理时将句子动态填充至批次中的最长长度,比将整个数据集填充至最大长度更高效。
评估
然后创建一个函数,将您的预测结果和标签传递给 compute 来计算 ROUGE 指标:
您的 compute_metrics 函数已准备就绪,在设置训练时会用到它。
训练
此时,只剩三个步骤:
在
Seq2SeqTrainingArguments中定义训练超参数。唯一必需的参数是output_dir,它指定保存模型的位置。通过设置push_to_hub=True,将模型推送到 Hub(您需要登录 Hugging Face 才能上传模型)。每个 epoch 结束时,Trainer将评估 ROUGE 指标并保存训练检查点。将训练参数传递给
Seq2SeqTrainer,同时传入模型、数据集、分词器、数据整理器和compute_metrics函数。调用
train()微调您的模型。
训练完成后,使用 push_to_hub() 方法将模型分享到 Hub,让所有人都能使用您的模型:
推断
很好,现在您已经微调了模型,可以用它进行推断了!
准备一些您想要生成摘要的文本。对于 T5,您需要根据所处理的任务为输入添加前缀。对于摘要任务,前缀如下所示:
对文本进行分词并将 input_ids 作为 PyTorch 张量返回:
使用 generate() 方法创建摘要。有关不同文本生成策略和控制生成参数的更多详情,请查阅文本生成 API。
将生成的词元 id 解码回文本: