仅用于站内搜索,没有排版格式,具体信息请跳转上方微信公众号内链接
点击上方“小白学视觉”,选择加\“星标\“或“置顶”
重磅干货,第一时间送达
大家好,我是吴师兄。
前天写了一篇文章为什么我还是无法理解Transformer?,阅读和转发量都挺高的,索性继续这个话题,聊聊Attention。
Attention模块明明只有一个公式:
但为什么看起来像黑魔法一样玄学?
今天我们就来一次「去魅化」,从反向传播、矩阵计算、参数更新的角度,聊聊Attention到底哪里容易卡住,以及如何搞懂它。
众所周知,Q(Query)、K(Key)、V(Value)是Attention模块的输入。
但真正的“坑”在于:很多讲解告诉你“Q去和K计算相似度,再用这个相似度加权V”,就戛然而止了,完全没有解释——
Q、K、V是怎么来的?参数是怎么更新的?它真的可以被训练吗?
实际上我们在Transformer的实现中,都会看到类似这段代码:
乍一看只有几行代码,但里面全是门道。
每一个Linear层其实都是含参数的全连接层,q_linear,k_linear,v_linear,o_linear都有自己的W和b——这些参数就是通过反向传播训练出来的!
很多人会问:
“
“Transformer好像只有Attention操作,没看到什么网络结构,它能训练吗?”
答案是肯定的,Attention本质上只是一个“可导的矩阵运算块”。
Transformer=多层堆叠的子结构,每层的基本结构如下:
Attention里面用到的QK^T/sqrt(d_k)是标准矩阵运算
softmax也是可导函数
最后的加权求和⋅V仍然是矩阵乘法
也就是说:整个Attention模块组成了一条可导路径,在PyTorch里每一行都是可以backward()的!
如果你还不太习惯公式,那我们换一种方式来讲Attention的本质:
“
“拿Q去跟每个K计算相似度,然后根据相似度去加权V。”
具体过程:
输入被映射为Q、K、V(都是线性层带参数的)
对Q和K做点积,得到一个相关性分数矩阵
用softmax把分数转换为概率(即注意力权重)
用这些权重对V做加权求和
整个流程可以总结为:“带权重的加权平均”。
Transformer能训练,靠的是全图自动微分。
比如你这样写:
那下面这些Linear层都会自动更新:
q_linear/k_linear/v_linear/o_linear
FFN中的两层全连接
甚至Embedding层的参数
Attention本身(矩阵乘法+softmax+加权和)只是中间计算节点,每一步都支持链式求导,没有中断。
换句话说:
“
你在前向传播里写的“Attention”,会在反向传播时被“拆解成小操作”逐层求导!
之前我也写过一篇文章,也被很多大厂面试问到:
“
Transformer的Attention为什么要除以√dₖ?
被问懵了!Transformer的Attention为什么要除以√dₖ?
原因主要有两个:
避免梯度消失:随着dₖ增大,点积结果变大,softmax输出趋于极端→梯度接近0加个缩放项能让softmax更“温柔”,保持梯度稳定。
让不同长度、不同维度的输入分布更一致就像做标准化一样,是为了数值稳定性。
一句话总结:这是训练稳定性的工程trick,不能少!
“
相对位置编码(RoPE)是用在Q和K上的,它不是直接加,而是旋转。
简单来说:
在计算QK^T之前,先对Q和K做「带位置信息的旋转变换」
这种旋转是根据token的相对位置来的
用的是复数或旋转矩阵的技巧
不需要新增参数,且天然支持长序列外推
位置变了→向量的“朝向”就变了→相似度也变了
这是不是很优雅!
网上很多讲解:
要么只讲公式,没结合代码和反向传播
要么只解释功能(什么让每个词关注别的词),却不说原理
要么一堆比喻,但没有让你“跑起来一遍”
我建议你从训练一个简化版的Attention开始,跑一遍+打一遍断点,所有的黑盒就都变成白盒了。
如果你也曾被Attention搞晕,不妨记住这几点:
Attention是网络,不是trick,它完全可以训练
Q、K、V来自Linear层,带参数,可以反向传播
softmax、点积、矩阵乘法都可导,是链式计算图的一部分
Transformer的训练和CNN一样,靠的是loss. backward()
除以√dₖ是为了训练稳定
RoPE是在Q、K上做的位置旋转,帮助模型建模相对位置信息
别再害怕Attention,它不是魔法,只是数学。
下载1:OpenCV-Contrib扩展模块中文版教程
在「小白学视觉」公众号后台回复:扩展模块中文教程,即可下载全网第一份OpenCV扩展模块教程中文版,涵盖扩展模块安装、SFM算法、立体视觉、目标跟踪、生物视觉、超分辨率处理等二十多章内容。
下载2:Python视觉实战项目52讲
在「小白学视觉」公众号后台回复:Python视觉实战项目,即可下载包括图像分割、口罩检测、车道线检测、车辆计数、添加眼线、车牌识别、字符识别、情绪检测、文本内容提取、面部识别等31个视觉实战项目,助力快速学校计算机视觉。
下载3:OpenCV实战项目20讲
在「小白学视觉」公众号后台回复:OpenCV实战项目20讲,即可下载含有20个基于OpenCV实现20个实战项目,实现OpenCV学习进阶。
交流群
欢迎加入公众号读者群一起和同行交流,目前有SLAM、三维视觉、传感器、自动驾驶、计算摄影、检测、分割、识别、医学影像、GAN、算法竞赛等微信群(以后会逐渐细分),请扫描下面微信号加群,备注:”昵称+学校/公司+研究方向“,例如:”张三+上海交大+视觉SLAM“。请按照格式备注,否则不予通过。添加成功后会根据研究方向邀请进入相关微信群。请勿在群内发送广告,否则会请出群,谢谢理解~