0
  • 聊天消息
  • 系统消息
  • 评论与回复
登录后你可以
  • 下载海量资料
  • 学习在线课程
  • 观看技术视频
  • 写文章/发帖/加入社区
会员中心
创作中心

完善资料让更多小伙伴认识你,还能领取20积分哦,立即完善>

3天内不再提示

小白学大模型:大模型加速的秘密 FlashAttention 1/2/3

颖脉Imgtec ? 2025-09-10 09:28 ? 次阅读
加入交流群
微信小助手二维码

扫码添加小助手

加入工程师交流群

本文转自:Coggle数据科学


在 Transformer 架构中,注意力机制的计算复杂度与序列长度(即文本长度)呈平方关系()。这意味着,当模型需要处理更长的文本时(比如从几千个词到几万个词),计算时间和所需的内存会急剧增加。最开始的标准注意力机制存在两个主要问题:

  1. 内存占用高:模型需要生成一个巨大的注意力矩阵 (N×N)。这个矩阵需要被保存在高带宽内存 (HBM)中。对于长序列,这很快就会超出 GPU 的内存容量。
  2. 计算效率低:标准实现会将注意力计算分解成多个独立的步骤(矩阵乘法、softmax 等)。每一步都需要将数据从速度较慢的 HBM 中读取,计算后又写回 HBM。这种频繁的数据移动(内存读写)成为了性能瓶颈,导致 GPU 的计算单元(如 Tensor Cores)利用率低下。

什么是 FlashAttention?

FlashAttention 使得处理长达数万甚至数十万个 token 的超长文本成为可能。这解锁了新的应用场景,例如分析法律文档、总结长篇小说或处理整个代码库。

FlashAttention 使得模型的训练和推理速度更快,尤其是在长序列场景下。例如,FlashAttention-2 在长序列上比标准实现快 10 倍,使得训练成本更低,用户体验更好。

最新的 FlashAttention-3 利用了新硬件(如 NVIDIA H100)的 FP8 精度,进一步提升了性能,同时通过特殊的算法保持了计算的准确性,让模型训练更加高效。

FlashAttention v1

许多研究提出了近似注意力方法,试图通过减少计算量(FLOPs)来提高效率。然而,这些方法通常忽略了GPU不同层级内存(如高速的片上SRAM和相对较慢的高带宽HBM)之间的I/O开销,导致它们在实际运行时并没有带来显著的加速。

77ce35b0-8de5-11f0-8ce9-92fbcf53809c.png

FlashAttention的核心思想是I/O感知,即在设计算法时,将数据在不同层级内存之间的读写开销考虑在内。论文指出,在现代GPU上,计算速度已经远超内存访问速度,因此大多数操作都受限于内存访问。FlashAttention通过以下两个关键技术来解决这一问题:

  • Tiling (平铺):将输入数据(Q、K、V矩阵)分割成小块,并在GPU的片上SRAM中进行计算。这样可以避免将庞大的 N×N 注意力矩阵完整地写入到速度较慢的HBM中。
  • 内存优化:在反向传播时,FlashAttention 不存储巨大的中间注意力矩阵,而是只保存前向传播中计算出的Softmax归一化因子。这样,反向传播时可以利用这些因子在SRAM中快速地重新计算注意力矩阵,从而避免了从HBM读取大矩阵的开销。

GPU内存层级

  • HBM (高带宽内存):容量大(如A100 GPU的40-80 GB),但速度相对较慢(带宽1.5-2.0 TB/s)。
  • 片上SRAM (静态随机存取存储器):容量小(每个流式多处理器有192 KB),但速度极快(带宽估计达19 TB/s),比HBM快一个数量级以上。

由于GPU的计算速度增长快于内存速度,许多操作的性能瓶颈在于内存访问,而不是计算本身。因此,如何高效利用快速的SRAM变得至关重要。

运算类型

根据算术强度(每字节内存访问的算术运算次数),操作可分为两类:

  • 计算密集型 (Compute-bound):运算时间由算术操作数量决定,内存访问时间相对较小。例如,大规模矩阵乘法。
  • 内存密集型 (Memory-bound):运算时间由内存访问次数决定,计算时间相对较小。例如,大多数元素级操作(如激活函数、Dropout)和归约操作(如Softmax、LayerNorm)。

注意力实现改进

77e48018-8de5-11f0-8ce9-92fbcf53809c.png

给定查询 Q、键 K 和值 V 矩阵,注意力的计算分三步:

  1. 相似度计算
  2. Softmax归一化
  3. 加权求和

标准实现(如“Algorithm 0”所示)将每一步都作为一个独立的GPU核函数,并物化(materialize)中间矩阵 S 和 P 到HBM中。

这种实现方式导致了两个主要问题:

  • 巨大的内存占用:中间矩阵 S 和 P 的大小为 N×N,其内存占用与序列长度 N 的平方成正比。
  • 大量的HBM访问:由于每个步骤都需要读写HBM,导致I/O开销巨大。论文指出,这种方法对HBM的访问次数是 O(N2) 级别的,这在长序列(通常 N?d)时会成为主要的性能瓶颈,导致运行时间慢。

77f375f0-8de5-11f0-8ce9-92fbcf53809c.png

FlashAttention旨在减少对GPU高带宽内存(HBM)的读写,实现对确切注意力(exact attention)的快速、内存高效的计算。为此,它采用了两种关键技术:

  1. Tiling(分块):将输入的 Q,K,V 矩阵分成若干小块。然后,在计算过程中,每次只将一小块数据从慢速的HBM加载到快速的片上SRAM进行计算,而不是一次性加载整个大矩阵。
  2. Recomputation(重计算):为了避免在反向传播时存储 O(N2) 的中间注意力矩阵 S 和 P,FlashAttention只存储 Softmax 的归一化统计量(即 m 和 ?)。在反向传播时,它会利用这些统计量,按需在SRAM中重新计算必要的注意力矩阵块。

通过Tiling和Recomputation,FlashAttention能够将所有计算步骤(矩阵乘法、Softmax、可选的遮蔽和Dropout)融合成一个单一的CUDA核函数。这避免了在每个步骤之间反复地将数据写入HBM。

实现效果

lashAttention在BERT-large模型上的训练速度超过了MLPerf 1.1的记录保持者。与Nvidia的实现相比,FlashAttention的训练时间缩短了15%,这证明了其在标准长序列任务上的卓越性能。

77fdef58-8de5-11f0-8ce9-92fbcf53809c.png

FlashAttention在训练GPT-2模型时,相比于流行的HuggingFace和Megatron-LM实现,实现了显著的端到端加速。

780c01e2-8de5-11f0-8ce9-92fbcf53809c.png

  • 与Huggingface相比,速度提升高达3倍
  • 与Megatron-LM相比,速度提升高达1.7倍
  • 重要的是,FlashAttention在不改变模型定义的情况下,实现了与基线模型相同的困惑度(perplexity),证明了其数值稳定性

在Long-Range Arena基准测试中,FlashAttention相比于标准的Transformer实现,实现了2.4倍的加速。此外,块稀疏FlashAttention的表现甚至优于所有已测试的近似注意力方法,证明了其在处理超长序列时的优越性。

lashAttention的内存占用与序列长度呈线性关系,而标准实现是平方关系。这使得FlashAttention的内存效率比标准方法高出20倍

FlashAttention v2

第一代FlashAttention通过利用 GPU 内存层次结构的特性,显著降低了内存占用(从二次方降为线性)并实现了 2-4 倍的加速,且没有引入任何近似。

然而,FlashAttention 的效率仍然不如优化的矩阵乘法(GEMM)操作,其浮点运算性能(FLOPs/s)仅能达到理论峰值的 25-40%。这主要是因为 FlashAttention 存在不优化的工作划分(work partitioning),导致 GPU 线程块(thread blocks)和线程束(warps)之间的并行度不足、占用率低或产生不必要的共享内存读写。

为了解决这些问题,论文提出了FlashAttention-2,通过以下改进实现了更好的工作划分:

  1. 减少非矩阵乘法(non-matmul)的浮点运算:虽然这类操作占总 FLOPs 的比例小,但执行起来很慢。
  2. 在序列长度维度上并行化:即使对于单个注意力头,也将其计算任务分配给不同的线程块,以提高 GPU 的占用率。
  3. 优化线程块内部的工作分配:在每个线程块内,重新分配线程束之间的工作,以减少通过共享内存进行的通信

前向传播改进

FlashAttention-2对在线 Softmax 技巧进行了两处微调:

78196f9e-8de5-11f0-8ce9-92fbcf53809c.png

  1. 延迟归一化:在每个循环迭代中,不立即对输出进行归一化。相反,它维护一个“未缩放”的中间结果,并在整个循环结束时仅进行一次最终的归一化。这减少了每个块的缩放操作,从而减少了非 matmul 的 FLOPs。
  2. 简化统计量:为反向传播存储数据时,只保存logsumexp统计量 L(j)=m(j)+log(?(j)),而不是同时存储最大值 m(j) 和指数和 ?(j)。

并行化改进

第一代 FlashAttention 仅在批处理大小和注意力头数量上进行并行化。当序列长度很长时,批处理大小通常很小,导致 GPU 资源的利用率(occupancy)不高。FlashAttention-2 通过在序列长度维度上增加并行化来解决这个问题。

78279600-8de5-11f0-8ce9-92fbcf53809c.png

  • 前向传播:FlashAttention-2 将注意力矩阵的行块任务分配给不同的线程块,这些线程块之间无需通信。通过在行维度上并行,当批次大小和注意力头数较小时,GPU 的 SM(流式多处理器)能够被更充分地利用,从而提高整体吞吐量。
  • 后向传播:类似地,后向传播则在注意力矩阵的列块上进行并行。由于反向传播中的某些更新需要跨线程块通信,作者使用了原子加法(atomic adds)来更新共享的梯度 dK 和 dV,确保了线程安全。

783873a8-8de5-11f0-8ce9-92fbcf53809c.png


除了线程块级别的并行,FlashAttention-2 还优化了线程块内部线程束之间的工作分配,以减少共享内存的读写。

  • 前向传播

    • FlashAttention:采用“split-K”方案,将 K 和 V 矩阵的计算任务分配给不同的线程束。这要求所有线程束将中间结果写入共享内存,再进行同步和求和,导致不必要的共享内存访问。
    • FlashAttention-2:改为将 Q 矩阵的计算任务分配给不同的线程束。每个线程束负责计算 Q 的一个分片与完整的 K 的乘积。这样,每个线程束可以独立地完成其部分输出,而无需与其他线程束进行共享内存通信,从而显著提高了效率。
  • 后向传播:后向传播的依赖关系更复杂,但 FlashAttention-2 仍然通过避免“split-K”方案来减少共享内存的读写,实现了性能提升。

实现效果

FlashAttention-2 比第一代 FlashAttention 快1.7-3.0 倍,比 Triton 实现的 FlashAttention 快1.3-2.5 倍

7847a2ba-8de5-11f0-8ce9-92fbcf53809c.png


在 A100 GPU 上,FlashAttention-2 在前向传播中达到了230 TFLOPs/s的峰值,相当于理论最大吞吐量的73%。在后向传播中,它达到了理论最大吞吐量的 63%。

FlashAttention v3

虽然之前的 FlashAttention 通过减少内存读写来加速计算,但它未能充分利用现代硬件(如 Hopper GPU)的新特性。例如,FlashAttention-2 在 H100 GPU 上的利用率仅为 35%。

与 FlashAttention-2 类似,FlashAttention-3 也将任务并行化到不同的线程块(CTA),但其创新之处在于在单个线程块内部,将线程束(warps)划分为不同的角色。

  • 生产者(Producer):负责将数据从 HBM(全局内存)异步加载到 SMEM(共享内存)。
  • 消费者(Consumer):在数据加载完成后,从 SMEM 读取数据并执行计算。

生产者和消费者通过一个循环缓冲区(circular buffer)进行同步。生产者将数据放入缓冲区,消费者从中取出。当缓冲区中的一个“阶段”被消费后,生产者就可以继续向其中加载新数据。

线程内部的 GEMM 和 Softmax 重叠

在标准 FlashAttention 中,GEMM 和 Softmax 存在顺序依赖:Softmax 必须在第一个 GEMM 计算完成后才能开始,而第二个 GEMM 必须等待 Softmax 的结果。

785910cc-8de5-11f0-8ce9-92fbcf53809c.png7863332c-8de5-11f0-8ce9-92fbcf53809c.png


FlashAttention-3 通过在寄存器使用额外的缓冲区,打破了这种依赖关系。在每次循环中,它异步启动下一个 GEMM 的计算,而同时执行当前 GEMM 结果的 Softmax 和更新操作。这样,GEMM 和 Softmax 的执行就可以重叠,提高了效率。

FP8 低精度计算

FP8 的 WGMMA(Warp Group Matrix-Multiply-Accumulate)指令要求输入矩阵具有特定的k-major 布局,而输入张量通常是mn-major 布局

786d690a-8de5-11f0-8ce9-92fbcf53809c.png

FlashAttention-3 选择在 GPU 内核中(in-kernel)进行转置。它利用 LDSM/STSM 指令,这些指令能够高效地在 SMEM 和 RMEM(寄存器)之间进行数据传输,并在传输过程中完成布局转置,避免了代价高昂的 HBM 读写。

同于传统的逐张量(per-tensor)量化,FlashAttention-3 对每个进行单独量化。这使得每个块可以有自己的缩放因子,从而更有效地处理离群值,减少量化误差。

实现效果

FlashAttention-3 的前向传播速度比 FlashAttention-2 快1.5-2.0 倍,后向传播快1.5-1.75 倍。FP16 版本的 FlashAttention-3 达到了740 TFLOPs/s的峰值,相当于 H100 GPU 理论最大吞吐量的 **75%**。

787c0cda-8de5-11f0-8ce9-92fbcf53809c.png

在处理中长序列(1k 及以上)时,FlashAttention-3 的性能甚至超过了 NVIDIA 自家闭源、针对 H100 优化的cuDNN库。


声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉
  • 内存
    +关注

    关注

    8

    文章

    3143

    浏览量

    75588
  • 人工智能
    +关注

    关注

    1811

    文章

    49355

    浏览量

    253597
  • 大模型
    +关注

    关注

    2

    文章

    3289

    浏览量

    4385
收藏 人收藏
加入交流群
微信小助手二维码

扫码添加小助手

加入工程师交流群

    评论

    相关推荐
    热点推荐

    3d模型问题

    最近在学3d模型,遇到些问题,请教大家。1.3D模型的格式。模型是用3dmax做的,为3ds格式
    发表于 11-12 17:14

    AD的3D模型绘制功能介绍

    `  首先,在封装库的编辑界面下,我们点击菜单栏目的Place-》3D Body,见图(1)。    图(13D模型打开步骤  打开后就会
    发表于 01-14 16:48

    MRAS模型和可调模型参考

    1、简写MRAS参考模型和可调模型参考模型和可调模型方程:简写为如下形式:参考模型:可调
    发表于 08-27 06:44

    压缩模型加速推理吗?

    位压缩和“无”配置下都运行了 115 毫秒,尽管精度有所下降。我认为将 float 网络参数压缩为 uint8_t 不仅可以节省内存,还可以加快推理速度。那么,压缩模型是否应该加速推理?
    发表于 01-29 06:24

    3D模型基础

    1. 无论是人物、场景还是特效粒子系统等, 归根到底都是3D模型2. 导入一个人物模型, 点击Scene场景的Shaded按钮, 选择渲
    发表于 03-03 06:08 ?13次下载
    <b class='flag-5'>3</b>D<b class='flag-5'>模型</b>基础

    LTC2175/4/3/2/1/0 IBIS模型

    LTC2175/4/3/2/1/0 IBIS模型
    发表于 04-10 14:38 ?2次下载
    LTC2175/4/<b class='flag-5'>3</b>/<b class='flag-5'>2</b>/<b class='flag-5'>1</b>/0 IBIS<b class='flag-5'>模型</b>

    小白开始RTOS 1

    小白qi开始RTOS 1前言一、不知道从什么地方开始学习二、使用步骤1.引入库2.读入数据总结前言一、pandas是什么?二、使用步骤
    发表于 12-03 09:51 ?0次下载
    <b class='flag-5'>小白</b>开始<b class='flag-5'>学</b>RTOS <b class='flag-5'>1</b>

    如何改进和加速扩散模型采样的方法1

      尽管扩散模型实现了较高的样本质量和多样性,但不幸的是,它们在采样速度方面存在不足。这限制了扩散模型在实际应用中的广泛采用,并导致了从这些模型加速采样的研究领域的活跃。在 Part
    的头像 发表于 05-07 14:25 ?2875次阅读
    如何改进和<b class='flag-5'>加速</b>扩散<b class='flag-5'>模型</b>采样的方法<b class='flag-5'>1</b>

    自动驾驶车辆控制(车辆运动模型

    本文应配合b站up主“ 忠厚老实的老王 ”的 自动驾驶控制算法 系列视频食用。文章目录1. 两个车辆运动模型 1.1 三个坐标系 1.2符号定义 1.3车辆运动
    发表于 06-07 11:53 ?0次下载
    自动驾驶车辆控制(车辆运动<b class='flag-5'>学</b><b class='flag-5'>模型</b>)

    如何加速生成2 PyTorch扩散模型

    加速生成2 PyTorch扩散模型
    的头像 发表于 09-04 16:09 ?1582次阅读
    如何<b class='flag-5'>加速</b>生成<b class='flag-5'>2</b> PyTorch扩散<b class='flag-5'>模型</b>

    加速度传感器的基本力学模型是什么

    加速度传感器的基本力学模型是一个受力物体的运动和动力学模型的组合。本文将从以下几个方面介绍加速度传感器的基本力学
    的头像 发表于 01-17 11:08 ?2125次阅读

    写给小白的大模型入门科普

    什么是大模型?大模型,英文名叫LargeModel,大型模型。早期的时候,也叫FoundationModel,基础模型。大模型是一个简称。完
    的头像 发表于 11-23 01:06 ?823次阅读
    写给<b class='flag-5'>小白</b>的大<b class='flag-5'>模型</b>入门科普

    小白模型:训练大语言模型的深度指南

    在当今人工智能飞速发展的时代,大型语言模型(LLMs)正以其强大的语言理解和生成能力,改变着我们的生活和工作方式。在最近的一项研究中,科学家们为了深入了解如何高效地训练大型语言模型,进行了超过
    的头像 发表于 03-03 11:51 ?927次阅读
    <b class='flag-5'>小白</b><b class='flag-5'>学</b>大<b class='flag-5'>模型</b>:训练大语言<b class='flag-5'>模型</b>的深度指南

    小白模型:从零实现 LLM语言模型

    在当今人工智能领域,大型语言模型(LLM)的开发已经成为一个热门话题。这些模型通过学习大量的文本数据,能够生成自然语言文本,完成各种复杂的任务,如写作、翻译、问答等。https
    的头像 发表于 04-30 18:34 ?791次阅读
    <b class='flag-5'>小白</b><b class='flag-5'>学</b>大<b class='flag-5'>模型</b>:从零实现 LLM语言<b class='flag-5'>模型</b>

    小白模型:国外主流大模型汇总

    )领域。论文的核心是提出了一种名为Transformer的全新模型架构,它完全舍弃了以往序列模型(如循环神经网络RNNs和卷积神经网络CNNs)中常用的循环和卷积结构
    的头像 发表于 08-27 14:06 ?173次阅读
    <b class='flag-5'>小白</b><b class='flag-5'>学</b>大<b class='flag-5'>模型</b>:国外主流大<b class='flag-5'>模型</b>汇总