Pytorch从0实现Transformer的实践

目录
  • 摘要
  • 一、构造数据
    • 1.1 句子长度
    • 1.2 生成句子
    • 1.3 生成字典
    • 1.4 得到向量化的句子
  • 二、位置编码
    • 2.1 计算括号内的值
    • 2.2 得到位置编码
  • 三、多头注意力
    • 3.1 self mask

摘要

With the continuous development of time series prediction, Transformer-like models have gradually replaced traditional models in the fields of CV and NLP by virtue of their powerful advantages. Among them, the Informer is far superior to the traditional RNN model in long-term prediction, and the Swin Transformer is significantly stronger than the traditional CNN model in image recognition. A deep grasp of Transformer has become an inevitable requirement in the field of artificial intelligence. This article will use the Pytorch framework to implement the position encoding, multi-head attention mechanism, self-mask, causal mask and other functions in Transformer, and build a Transformer network from 0.

随着时序预测的不断发展,Transformer类模型凭借强大的优势,在CV、NLP领域逐渐取代传统模型。其中Informer在长时序预测上远超传统的RNN模型,Swin Transformer在图像识别上明显强于传统的CNN模型。深层次掌握Transformer已经成为从事人工智能领域的必然要求。本文将用Pytorch框架,实现Transformer中的位置编码、多头注意力机制、自掩码、因果掩码等功能,从0搭建一个Transformer网络。

一、构造数据

1.1 句子长度

# 关于word embedding,以序列建模为例
# 输入句子有两个,第一个长度为2,第二个长度为4
src_len = torch.tensor([2, 4]).to(torch.int32)
# 目标句子有两个。第一个长度为4, 第二个长度为3
tgt_len = torch.tensor([4, 3]).to(torch.int32)
print(src_len)
print(tgt_len)

输入句子(src_len)有两个,第一个长度为2,第二个长度为4
目标句子(tgt_len)有两个。第一个长度为4, 第二个长度为3

1.2 生成句子

用随机数生成句子,用0填充空白位置,保持所有句子长度一致

src_seq = torch.cat([torch.unsqueeze(F.pad(torch.randint(1, max_num_src_words, (L, )), (0, max(src_len)-L)), 0) for L in src_len])
tgt_seq = torch.cat([torch.unsqueeze(F.pad(torch.randint(1, max_num_tgt_words, (L, )), (0, max(tgt_len)-L)), 0) for L in tgt_len])
print(src_seq)
print(tgt_seq)

src_seq为输入的两个句子,tgt_seq为输出的两个句子。
为什么句子是数字?在做中英文翻译时,每个中文或英文对应的也是一个数字,只有这样才便于处理。

1.3 生成字典

在该字典中,总共有8个字(行),每个字对应8维向量(做了简化了的)。注意在实际应用中,应当有几十万个字,每个字可能有512个维度。

# 构造word embedding
src_embedding_table = nn.Embedding(9, model_dim)
tgt_embedding_table = nn.Embedding(9, model_dim)
# 输入单词的字典
print(src_embedding_table)
# 目标单词的字典
print(tgt_embedding_table)

字典中,需要留一个维度给class token,故是9行。

1.4 得到向量化的句子

通过字典取出1.2中得到的句子

# 得到向量化的句子
src_embedding = src_embedding_table(src_seq)
tgt_embedding = tgt_embedding_table(tgt_seq)
print(src_embedding)
print(tgt_embedding)

该阶段总程序

import torch
# 句子长度
src_len = torch.tensor([2, 4]).to(torch.int32)
tgt_len = torch.tensor([4, 3]).to(torch.int32)
# 构造句子,用0填充空白处
src_seq = torch.cat([torch.unsqueeze(F.pad(torch.randint(1, 8, (L, )), (0, max(src_len)-L)), 0) for L in src_len])
tgt_seq = torch.cat([torch.unsqueeze(F.pad(torch.randint(1, 8, (L, )), (0, max(tgt_len)-L)), 0) for L in tgt_len])
# 构造字典
src_embedding_table = nn.Embedding(9, 8)
tgt_embedding_table = nn.Embedding(9, 8)
# 得到向量化的句子
src_embedding = src_embedding_table(src_seq)
tgt_embedding = tgt_embedding_table(tgt_seq)
print(src_embedding)
print(tgt_embedding)

二、位置编码

位置编码是transformer的一个重点,通过加入transformer位置编码,代替了传统RNN的时序信息,增强了模型的并发度。位置编码的公式如下:(其中pos代表行,i代表列)

2.1 计算括号内的值

# 得到分子pos的值
pos_mat = torch.arange(4).reshape((-1, 1))
# 得到分母值
i_mat = torch.pow(10000, torch.arange(0, 8, 2).reshape((1, -1))/8)
print(pos_mat)
print(i_mat)

2.2 得到位置编码

# 初始化位置编码矩阵
pe_embedding_table = torch.zeros(4, 8)
# 得到偶数行位置编码
pe_embedding_table[:, 0::2] =torch.sin(pos_mat / i_mat)
# 得到奇数行位置编码
pe_embedding_table[:, 1::2] =torch.cos(pos_mat / i_mat)
pe_embedding = nn.Embedding(4, 8)
# 设置位置编码不可更新参数
pe_embedding.weight = nn.Parameter(pe_embedding_table, requires_grad=False)
print(pe_embedding.weight)

三、多头注意力

3.1 self mask

有些位置是空白用0填充的,训练时不希望被这些位置所影响,那么就需要用到self mask。self mask的原理是令这些位置的值为无穷小,经过softmax后,这些值会变为0,不会再影响结果。

3.1.1 得到有效位置矩阵

# 得到有效位置矩阵
vaild_encoder_pos = torch.unsqueeze(torch.cat([torch.unsqueeze(F.pad(torch.ones(L), (0, max(src_len) - L)), 0)for L in src_len]), 2)
valid_encoder_pos_matrix = torch.bmm(vaild_encoder_pos, vaild_encoder_pos.transpose(1, 2))
print(valid_encoder_pos_matrix)

3.1.2 得到无效位置矩阵

invalid_encoder_pos_matrix = 1-valid_encoder_pos_matrix
mask_encoder_self_attention = invalid_encoder_pos_matrix.to(torch.bool)
print(mask_encoder_self_attention)

True代表需要对该位置mask

3.1.3 得到mask矩阵
用极小数填充需要被mask的位置

# 初始化mask矩阵
score = torch.randn(2, max(src_len), max(src_len))
# 用极小数填充
mask_score = score.masked_fill(mask_encoder_self_attention, -1e9)
print(mask_score)

算其softmat

mask_score_softmax = F.softmax(mask_score)
print(mask_score_softmax)

可以看到,已经达到预期效果

到此这篇关于Pytorch从0实现Transformer的实践的文章就介绍到这了,更多相关Pytorch Transformer内容请搜索我们以前的文章或继续浏览下面的相关文章希望大家以后多多支持我们!

(0)

相关推荐

  • python量化之搭建Transformer模型用于股票价格预测

    目录 前言 1.Transformer模型 2.环境准备 3.代码实现 3.1. 导入库以及定义超参 3.2. 模型构建 3.3. 数据预处理 3.4. 模型训练以及评估 3.5. 模型运行 4.总结 前言 下面的这篇文章主要教大家如何搭建一个基于Transformer的简单预测模型,并将其用于股票价格预测当中.原代码在文末进行获取. 1.Transformer模型 Transformer 是 Google 的团队在 2017 年提出的一种 NLP 经典模型,现在比较火热的 Bert 也是基于

  • Pytorch从0实现Transformer的实践

    目录 摘要 一.构造数据 1.1 句子长度 1.2 生成句子 1.3 生成字典 1.4 得到向量化的句子 二.位置编码 2.1 计算括号内的值 2.2 得到位置编码 三.多头注意力 3.1 self mask 摘要 With the continuous development of time series prediction, Transformer-like models have gradually replaced traditional models in the fields of

  • PyTorch 1.0 正式版已经发布了

    PyTorch 1.0 同时面向产品化 AI 和突破性研究的发展,「我们在 PyTorch1.0 发布前解决了几大问题,包括可重用.性能.编程语言和可扩展性.」Facebook 人工智能副总裁 Jerome Pesenti 曾在PyTorch 开发者大会上表示. 随着 PyTorch 生态系统及社区中有趣新项目及面向开发者的教育资源不断增加,今天 Facebook 在 NeurIPS 大会上发布了 PyTorch 1.0 稳定版.该版本具备生产导向的功能,同时还可以获得主流云平台的支持. 现在,

  • vue3.0+vue-router+element-plus初实践

    Vue3中文文档 Vue3.0对比Vue2.x优势 框架内部做了大量的性能优化,包括:虚拟dom,编译模板,Proxy的新数据监听,更小的打包文件等. 新增的组合式API(即Composition API),更适合大型项目的编写方式. 对TypeScript支持更好,去掉this操作,更强大的类型推导. 初始化项目 安装@vue/cli npm install @vue/cli -g 或 yarn global add @vue/cli 创建项目 vue create 项目名 可以选择Vue3的

  • pytorch中[..., 0]的用法说明

    在看程序的时候看到了x[-, 0]的语句不是很理解,后来自己做实验略微了解,以此记录方便自己查看. b=torch.Tensor([[[[10,2],[4,5],[7,8]],[[1,2],[4,5],[7,8]]]]) print(b.size()) (1, 2, 3, 2) print(b[-,0]) tensor([[[10., 4., 7.], [ 1., 4., 7.]]]) print(b[-,0].size()) (1, 2, 3) print(b[-,2]) Traceback

  • sharding-jdbc5.0.0实现分表实践

    本文基于shardingsphere-jdbc-core-spring-boot-starter 5.0.0,请注意不同版本的sharding-jdbc配置可能有不一样的地方,本文不一定适用于其它版本 相关的maven配置如下:         <dependency>             <groupId>org.springframework.boot</groupId>             <artifactId>spring-boot-sta

  • Vuex2.0+Vue2.0构建备忘录应用实践

    一.介绍Vuex Vuex 是一个专为 Vue.js 应用程序开发的状态管理模式.它采用集中式存储管理应用的所有组件的状态,并以相应的规则保证状态以一种可预测的方式发生变化,适合于构建中大型单页应用. 1.什么是状态管理模式? 看个简单的例子: <!DOCTYPE html> <html> <head> <meta charset="utf-8"> <meta name="viewport" content=&q

  • vue2.0 keep-alive最佳实践

    vue2.0 keep-alive的最佳实践,供大家参考,具体内容如下 1.基本用法 vue2.0提供了一个keep-alive组件用来缓存组件,避免多次加载相应的组件,减少性能消耗 <keep-alive> <component> <!-- 组件将被缓存 --> </component> </keep-alive> 有时候 可能需要缓存整个站点的所有页面,而页面一般一进去都要触发请求的 在使用keep-alive的情况下 <keep-al

  • 详解PyTorch手写数字识别(MNIST数据集)

    MNIST 手写数字识别是一个比较简单的入门项目,相当于深度学习中的 Hello World,可以让我们快速了解构建神经网络的大致过程.虽然网上的案例比较多,但还是要自己实现一遍.代码采用 PyTorch 1.0 编写并运行. 导入相关库 import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from torchvision import datasets, t

  • pytorch + visdom CNN处理自建图片数据集的方法

    环境 系统:win10 cpu:i7-6700HQ gpu:gtx965m python : 3.6 pytorch :0.3 数据下载 来源自Sasank Chilamkurthy 的教程: 数据:下载链接. 下载后解压放到项目根目录: 数据集为用来分类 蚂蚁和蜜蜂.有大约120个训练图像,每个类有75个验证图像. 数据导入 可以使用 torchvision.datasets.ImageFolder(root,transforms) 模块 可以将 图片转换为 tensor. 先定义transf

  • windows系统快速安装pytorch的详细图文教程

    pip和conda的区别 之前一直使用conda和pip ,有时候经常会两者混用.但是今天才发现二者装的东西不是在一个地方的,所以发现有的东西自己装了,但是在运行环境的时候发现包老是识别不了,一直都特别疑惑,直到今天注意到这个问题,所以来总结一下二者的区别. pip pip专门管理Python包 编译源码中的所有内容. (源码安装) 由核心Python社区所支持(即,Python 3.4+包含可自动增强pip的代码). conda Python不可知论者. 现有软件包的主要重点是Python,而

随机推荐