“本文转载自 Henry Ko 的 Blog,深入地解析了 TPU 的相关知识。”
我最近一直在大量使用 TPU,很有趣地发现它们与 GPU 的设计理念有多么不同。
TPU 的主要优势在于其可扩展性。这是通过硬件(例如,能效和模块化)和软件(例如,XLA 编译器)的协同设计实现的。
背景
简单来说,TPU 是谷歌的专用集成电路 (ASIC),专注于两个因素:极高的矩阵乘法吞吐量 + 高能效。
它们的起源可以追溯到 2006 年的谷歌,当时他们首次评估应该采用 GPU、FPGA 还是定制 ASIC。那时只有少数应用需要专门的硬件,他们认为这些需求可以通过利用其大型数据中心多余的 CPU 计算能力来满足。但情况在 2013 年发生了变化,当时谷歌的语音搜索功能开始在神经网络上运行,内部预测显示,如果该功能大受欢迎,他们将需要更多的计算能力。
快进到今天,TPU 为谷歌的大部分 AI 服务提供了动力。当然,这包括 Gemini 或 Veo 的训练和推理,也包括部署他们的推荐模型 (DLRM)。
让我们从底层向上深入了解 TPU 的内部结构。
TPU 单芯片层面
我的图表将主要关注 TPUv4,但这种布局或多或少也适用于最新一代的 TPU(例如,TPUv6p "Trillium";截至 2025 年 6 月撰写本文时,TPUv7 "Ironwood" 的细节尚未公布)。
以下是单个 TPUv4 芯片的布局:

在每个芯片中,有两个 TPU TensorCore,它们负责计算。(注意:专用于推理的 TPU 只有一个 TensorCore)。两个 TensorCore 共享内存单元:CMEM (128MiB) 和 HBM (32GiB)。
在每个 TensorCore 内部,是我们的计算单元和更小的内存缓冲区:
-
矩阵乘法单元 (Matrix Multiply Unit, MXU)
-
这是 TensorCore 的关键组件,是一个 128x128 的脉动阵列 (systolic array)。
-
我们将在下面介绍脉动阵列。
-
向量单元 (Vector Unit, VPU)
-
用于通用的逐元素操作(例如 ReLU、逐点加/乘、规约操作)。
-
向量内存 (Vector Memory, VMEM; 32MiB)
-
内存缓冲区。数据在 TensorCore 进行任何计算之前,会从 HBM 复制到 VMEM 中。
-
标量单元 + 标量内存 (Scalar Unit + SMEM; 10MiB)
-
告诉 VPU 和 MXU 该做什么。
-
管理控制流、标量操作和内存地址生成。
如果你熟悉 NVIDIA GPU,可能会对一些初步观察感到困惑:
-
TPU 上的片上内存单元(CMEM、VMEM、SMEM)比 GPU 上的 L1、L2 缓存大得多。
-
TPU 上的 HBM 也比 GPU 上的 HBM 小得多。
-
负责计算的“核心”似乎要少得多。
这与 GPU 的情况正好相反,GPU 拥有较小的 L1、L2 缓存(H100 分别为 256KB 和 50MB)、更大的 HBM(H100 为 80GB)以及数以万计的核心。
在我们进一步讨论之前,请记住 TPU 能够像 GPU 一样实现极高的吞吐量。TPU v5p 每芯片可达到 500 TFLOPs/秒,一个包含 8960 个芯片的完整 Pod 大约可以达到 4.45 ExaFLOPs/秒。据说最新的 "Ironwood" TPUv7 每个 Pod(9216 chips)最高可达 42.5 ExaFLOPS/秒。
要理解 TPU 是如何实现这一点的,我们需要了解它们的设计理念。
TPU 设计理念
TPU 依靠两大支柱和一个关键假设,实现了惊人的吞吐量和能效:脉动阵列 + 流水线技术、预先编译 (Ahead-of-Time, AoT),以及假设大多数操作都可以用一种能很好地映射到脉动阵列的方式来表达。幸运的是,在我们现代的深度学习时代,矩阵乘法占据了计算的大部分,这非常适合脉动阵列。
TPU 设计选择 #1: 脉动阵列 + 流水线技术
问:什么是脉动阵列 (Systolic Array)?
脉动阵列是一种硬件设计架构,由一个网格状的互连处理单元 (Processing Element, PE) 组成。每个 PE 执行一个小的计算(例如,乘法和累加),并将结果传递给相邻的 PE。

这种设计的好处在于,一旦数据被送入脉动阵列,就不再需要额外的控制逻辑来决定如何处理数据。此外,当脉动阵列足够大时,除了输入和输出,没有其他的内存读/写操作。
由于其刚性的组织结构,脉动阵列只能处理具有固定数据流模式的操作,但幸运的是,矩阵乘法和卷积完美地契合了这一范畴。
此外,通过流水线技术,可以将计算与数据移动重叠起来。下面是一个在 TPU 上进行流水线式逐点操作的图示。
题外话:脉动阵列的缺点 - 稀疏性
你可以看到脉动阵列非常喜欢密集矩阵(即每个 PE 在几乎每个周期都处于活动状态)。然而,其缺点在于,对于同样大小的稀疏矩阵,性能没有提升:即使是对于值为零的元素,PE 仍然需要执行相同数量的周期来进行计算。
如果深度学习社区更青睐不规则的稀疏性(例如 MoE),那么处理脉动阵列的这种系统性限制将变得更加重要。
TPU 设计选择 #2: 预先编译 (AoT) + 减少对缓存的依赖
本节回答了 TPU 如何通过TPU + XLA 编译器的软硬件协同设计来避免使用缓存,从而实现高能效。
首先,回想一下,传统缓存是为了处理不可预测的内存访问模式而设计的。一个应用程序的程序与另一个应用程序的程序可能有着截然不同的内存访问模式。本质上,缓存使硬件具有灵活性,能够适应各种应用。这是 GPU 成为非常灵活的硬件的一大原因(注:与 TPU 相比)。
然而,缓存访问(以及一般的内存访问)会消耗大量能量。下面是一个芯片上操作的能耗粗略估算(45nm, 0.9V; )。这里的关键信息是,内存访问和控制占据了我们大部分的能量,而算术运算的能耗则要低得多。

但是,如果你的应用非常特定,其计算/内存访问模式高度可预测呢?
举个极端的例子,如果我们的编译器能够提前计算出所有需要的内存访问,那么我们的硬件只需要一个便签式内存 (scratchpad memory) 作为缓冲区就足够了,完全不需要缓存。
这正是 TPU 理念所追求的,也正是为什么 TPU 与 XLA 编译器协同设计以实现这一目标。XLA 编译器通过提前分析计算图来生成优化的程序。
问:但是 JAX 也能很好地与 TPU 配合,可它们用的是 @jit?
JAX+XLA 在 TPU 上处于即时编译 (JIT) 和预先编译 (AOT) 的混合状态,因此会产生困惑。当我们第一次在 JAX 中调用一个 jit 函数时,JAX 会追踪它以创建一个静态计算图。这个图被传递给 XLA 编译器,在那里它被转换成一个完全静态的 TPU 二进制文件。正是在这最后的转换阶段,会进行 TPU 特定的优化(例如,最小化内存访问)来为 TPU 定制处理过程。
但有一个注意事项:如果 jit 函数以不同的输入形状运行,就必须重新编译和缓存。这就是为什么 JAX 在处理任何动态填充或依赖于输入的不同长度的 for 循环层时表现不佳的原因。
当然,这种方法听起来很好,但也有不便的缺点。它缺乏灵活性,并且对编译器的重度依赖是一把双刃剑。
但为什么谷歌仍然坚持这种设计理念呢?
TPU 与能效 (TPUv4)
前面那张能耗图并不能准确代表 TPU,所以这里是 TPUv4 的能耗分解。请注意,TPUv4 是 7nm 工艺,而 45nm 的数据仅作对比 。

左边的条形图直观地显示了数值,但需要注意的一点是,现代芯片使用 HBM3,其能耗远低于这里显示的 DDR3/4 DRAM。尽管如此,这表明内存操作的能耗要高出几个数量级。
这与现代规模法则 (scaling laws) 有很好的联系:我们非常乐意增加浮点运算次数 (FLOPS) 来换取内存操作的减少。因此,减少内存操作具有双重的优化效益,因为它们不仅使程序运行更快,而且能耗也更低。
TPU 多芯片层面
让我们再一层,看看 TPU 在多芯片环境中的工作方式。
Tray 层面 (又称 "板卡"; 4 芯片)

一个 TPU tray 包含 4 个 TPU 芯片或 8 个 TensorCore(简称为“核心”)。每个 tray 都有自己的 CPU 主机(注意:对于推理型 TPU,一个主机访问 2 个tray,因为它们每个芯片只有 1 个核心)。
主机与芯片的连接是 PCIe,但芯片与芯片之间的连接是核间互连 (Inter-Core Interconnect, ICI),它具有更高的带宽。
但 ICI 连接可以延伸到更远的多个托盘。为此,我们需要上升到机架 (Rack) 层面。
机架 (Rack) 层面 (4x4x4 芯片)
TPU 特别令人兴奋的部分在于其可扩展性,我们从机架层面开始看到这一点。
一个 TPU 机架由 64 个 TPU 组成,它们以 4x4x4 的 3D 环面 (torus) 结构连接。如果你看过谷歌下面这样的 TPU 宣传材料,那就是 8 个 TPU 机架的图像。

但在我们深入了解机架之前,需要澄清一些容易混淆的术语:机架 (rack) vs. Pod vs. Slice。
问:“TPU Rack”、“TPU Pod” 和 “TPU Slice” 之间有什么区别?
不同的谷歌资料对它们的使用略有不同,有时会将 "TPU Pods" 和 "TPU Slices" 互换使用。但在本文中,我们将遵循谷歌 TPU 论文和 GCP TPU 文档中的定义。
-
TPU Rack (机架):
-
包含 64 个芯片的物理单元。也称为“立方体 (cube)”。
-
-
TPU Pod:
-
可通过 ICI 和光纤连接的 TPU 的最大单元。
-
也称为 "Superpod" 或 "Full pod"。例如,TPUv4 的一个 TPU Pod 将由 4096 个芯片或 64 个 TPU 机架组成。
-
-
TPU Slice (切片):
-
介于 4 个芯片和 Superpod 大小之间的任何 TPU 配置。
-
关键区别在于,TPU Rack 和 TPU Pod 是物理度量单位,而 TPU Slice 是一个抽象单位。当然,设置 TPU Slice 有重要的物理属性,但我们暂时将其抽象化。
现在,我们将使用物理度量单位:TPU Racks 和 TPU Pods。这是因为了解 TPU 系统是如何物理连接的,可以帮助我们更好地理解 TPU 的设计理念。
现在回到 TPU 机架 (针对 TPUv4):
一个 TPU 机架由 64 个芯片组成,通过 ICI 和光路交换 (Optical Circuit Switching, OCS) 连接在一起。本质上,我们连接多个 tray 来模拟一个 64 芯片的系统。这种将小部件组合成超级计算机的主题在后面会继续出现。
下面是单个 TPUv4 机架的图示。它是一个 4x4x4 的 3D 环面,每个节点是一个芯片,蓝色的箭头是 ICI,而各个面上的线是 OCS。

然而,这张图引出了几个问题。为什么 OCS 只用于各个面?换句话说,使用 OCS 有什么好处?有 3 大好处,我们稍后会介绍另外两个。
OCS 的好处 #1: 环绕连接 (Wraparound)
通过环绕连接实现节点间更快的通信。
OCS 还充当给定 TPU 配置的环绕连接。这将两个节点之间的最坏情况跳数从 N-1 跳减少到每个轴 (N-1)/2 跳,因为每个轴都变成了一个环(1D 环面)。
随着我们进一步扩展,这种效应变得更加重要,因为减少芯片间通信延迟对于高并行化至关重要。
题外话:并非所有 TPU 都具有 3D 环面拓扑
注意:较早的 TPU 代(例如 TPUv2, v3)和推理 TPU(例如 TPUv5e, TPUv6e)具有 2D 环面拓扑,而不是像下面这样的 3D 环面。然而,TPUv7 "Ironwood" 似乎是 3D 环面,尽管它被宣传为推理芯片(注意:我只是根据他们的宣传材料进行假设)。

完整 Pod 层面 (又称 "Superpod"; TPUv4 为 4096 芯片)
就像我们将多个芯片连接起来组成一个 TPU 机架一样,我们可以连接多个机架来组成一个大型的 Superpod。
Superpod 也指 TPU 可以达到的(仅使用 ICI 和 OCS)互连芯片的最大配置。接下来还有一个多 Pod 层面,但这必须通过较慢的互连,我们稍后会讨论。
这个大小因代而异,但对于 TPUv4 是 4096 个芯片(即 64 个 4x4x4 芯片的机架)。对于最新的 TPUv7 "Ironwood",是 9216 个芯片。
下图显示了一个 TPUv4 的 Superpod。

请注意每个立方体(即一个 TPU 机架)是如何通过 OCS 相互连接的。这也允许我们在一个 Pod 中获取 TPU 的“切片”(slices)。
带 OCS 的 TPU 切片
我们可以在 Pod 内请求 TPU 的子集,这些就是 TPU 切片。但即使你想要 N 个芯片,也有多种拓扑可供选择。
例如,假设你总共需要 512 个芯片。你可以要求一个立方体 (cube) (8x8x8)、一个雪茄形 (cigar shape) (4x4x32) 或一个矩形 (rectangle) (4x8x16)。选择切片的拓扑本身就是一个超参数。
你选择的拓扑会影响节点之间的通信带宽。这直接影响不同并行化方法的性能。
例如,对于全局通信 (all-to-all),如数据并行或张量并行,立方体 (例如 8x8x8) 会是首选,因为它具有最高的对分带宽 (bisection bandwidth)。然而,对于流水线并行,雪茄形 (例如 4x4x32) 会更好,因为它可以更快地与顺序层通信(假设一个层适合一个 4x4 芯片的子切片)。

当然,最佳拓扑取决于模型,找到它本身就是一项工作。TPUv4 的论文 [9] 也对此进行了测量,以显示拓扑变化如何加速吞吐量(注意:我不确定第一行指的是哪种 LLM 架构,因为它没有具体说明)。

我们介绍了 TPU 切片,但还有一个重要特性有助于 TPU 的高运行稳定性。
那就是由于 OCS,这些切片不必是连续的机架。这是我们前面没有提到的使用 OCS 的第二个好处——可能也是最大的好处。
OCS 的好处 #2: (可重构的) 非连续多节点切片
请注意,这与硬连线多个节点来模拟非连续切片是不同的。由于 OCS 是一个交换机而不是硬连线,节点之间的物理线路要少得多,因此允许更高的可扩展性(即更大的 TPU Pod 尺寸)。
这允许大规模灵活的节点配置。例如,假设我们想在一个 Pod 上运行三个作业。虽然朴素的调度不允许这样做,但 OCS 连接允许我们抽象出节点的位置,并将整个 Pod 仅仅看作一个 “节点袋” (bag of nodes)。

这提高了 Pod 的利用率,并且在节点发生故障时可能使维护更容易。谷歌将其描述为**“故障节点的爆炸半径很小”。然而,我不确定当只有某些节点必须关闭时,其液体冷却会受到怎样的影响。
最后,这种灵活的 OCS 还有一个有趣的扩展:我们还可以改变 TPU 切片的拓扑,例如从常规环面变为扭曲环面 (twisted torus)。
OCS 的好处 #3: 扭曲的 TPU 拓扑
我们之前看到了如何通过改变固定芯片数量的 (x,y,z) 维度来获得不同的 TPU 切片拓扑。然而,这次我们将在固定的 (x,y,z) 维度下工作,但改变它们的连接方式以实现不同的拓扑。
一个显著的例子是从雪茄形的常规环面变为如下所示的扭曲雪茄环面。

扭曲环面允许在扭曲的 2D 平面上的芯片之间进行更快的通信。这对于加速全局通信 (all-to-all) 特别有用。
让我们更深入地探讨一下,想象一个具体的场景,这会有所帮助。
使用扭曲环面加速训练
理论上,扭曲环面对张量并行 (Tensor Parallel, TP) 的好处最大,因为每层都有多个 all-gather 和 reduce-scatter 操作。它可能对数据并行 (Data Parallel, DP) 带来中等的好处,因为每个训练步骤也有一个 all-reduce,但这会不那么频繁。
想象一下,我们正在训练一个标准的 decoder-only transformer,并且我们想采用大量的并行化来加速训练。我们将在下面看到两种情况。
场景 #1: 4x4x16 拓扑 (TP + PP; 总共 256 芯片)
我们的 z 轴将是我们的流水线并行 (Pipeline Parallel, PP) 维度,我们的 2D TP 维度将是 4x4。本质上,假设每个层 k 位于 z=k,并且每个层在 16 个芯片上分片。如果没有明确绘制,则假定标准的 OCS 连接(即最近邻)。

我们将在每个 z=k 处扭曲 2D 环面,这使得每个 TP 层中的芯片之间的通信更快。沿着我们的 PP 维度扭曲是不必要的,因为它们主要依赖于点对点通信。
注意: 实际上,当芯片数量大于 4x4 时,扭曲环面才会带来好处。我们在这里使用 4x4 仅为可视化目的。
场景 #2: 16x4x16 拓扑 (DP + TP + PP; 总共 1024 芯片)
作为扩展,我们将在之前的场景中添加一个大小为 4 的 DP 维度。这意味着沿着 x 轴有 4 个场景 #1 的模型。

请注意扭曲环面如何仅限于每个 DP 模型的每个 TP 维度(即,对于给定的 k=1…16,在每个 z=k 处的一个 4x4 的 2D 平面)。DP 维度只有一个环绕连接,以便每行成为一个大小为 16 的水平环。
你可能已经注意到,还有一种 8x8x16 的替代拓扑(即 2x2 DP 维度),但这变得更加复杂,因为我们混合了 DP 和 TP 维度。具体来说,不清楚我们应该如何为 y 轴构建 OCS 环绕连接,同时为每个 TP 维度容纳扭曲环面。
多 Pod 层面 (又称 "Multislice"; TPUv4 为 4096+ 芯片)

TPU 层次结构的最后一层是多 Pod 层面。在这里,你可以将多个 Pod 视为一台大型机器。然而,Pod 之间的通信是通过数据中心网络 (Data-Center Network, DCN)完成的,其带宽低于 ICI。

图示多 Pod 训练如何配置
PaLM 就是这样训练的。它花了 56 天在 6144 个 TPUv4(2个 Pod)上进行训练。下面你可以看到 6 个 Pod 上的 TPU 作业分配:绿色是 PaLM,红色是未分配,其余是其他作业。注意,每个方块是一个 4x4x4 的 TPU 立方体。
实现这一点本身就很困难,但更令人印象深刻的是对开发者体验的关注。具体来说,是关注 “我们如何才能尽可能地抽象化模型扩展的系统/硬件部分?” 这个问题。
谷歌的答案是让 XLA 编译器负责协调大规模芯片间的通信。通过研究人员提供的正确标志(即 DP、FSDP、TP 的并行维度、切片数量等),XLA 编译器会为手头的 TPU 拓扑插入正确的分层集合通信操作。目标是以尽可能少的代码更改来实现大规模训练。
例如,这里是谷歌博客 [1] 中跨多个切片的 all-reduce 操作的分解。

这表明 XLA 编译器负责处理切片之间和切片内部的通信集合操作。
举一个具体的例子,训练模型可能存在如下的 TPU 拓扑。激活值通信通过 ICI 在切片内发生,而梯度通信将通过 DCN 跨切片发生(即跨 DCN DP 维度)。

将图表与现实联系起来
当你看到硬件的实际照片时,将图表与现实联系起来会很有帮助。下面是一个总结。
如果你看过谷歌 TPU 宣传材料的图片,你可能见过下面这张图。

这是 8 个 TPU 机架,每个单元是我们上面看到的 4x4x4 的 3D 环面。一个机架中的每一行有 2 个托盘,这意味着每行有 8 个 TPU 芯片。
这是一个 TPUv4 的单托盘:

注意,图示被简化为只有一个 PCIe 端口,但实际托盘上有 4 个 PCIe 端口(在左侧)——每个 TPU 一个。
下面是一个单芯片:

中间部分是 ASIC,周围的 4 个块是 HBM 堆栈。我们看到的是一个 TPU v4,所以它内部有 2 个 TensorCore,因此总共有 4 个 HBM 堆栈。
我没找到 TPUv4 的芯片平面图,所以这里有一个 TPUv4i 的,它很相似,只是因为它是一个推理芯片,所以只有一个 TensorCore。

请注意 CMEM 在 TPUv4i 的布局上占了相当大的空间。
原文转载自:https://henryhmko.github.io/posts/tpu/tpu.html,经过翻译、校对。
-
FPGA
+关注
关注
1646文章
22059浏览量
619177 -
asic
+关注
关注
34文章
1250浏览量
122520 -
gpu
+关注
关注
28文章
4956浏览量
131439 -
TPU
+关注
关注
0文章
154浏览量
21223 -
KiCAD
+关注
关注
5文章
264浏览量
9581
发布评论请先 登录
AI芯片,需要ASIC
电机控制专用集成电路PDF版
TPU处理器的特性和工作原理

谷歌第七代TPU Ironwood深度解读:AI推理时代的硬件革命

谷歌新一代 TPU 芯片 Ironwood:助力大规模思考与推理的 AI 模型新引擎?
TPU编程竞赛系列|第九届集创赛“算能杯”火热报名中!

评论