0
  • 聊天消息
  • 系统消息
  • 评论与回复
登录后你可以
  • 下载海量资料
  • 学习在线课程
  • 观看技术视频
  • 写文章/发帖/加入社区
会员中心
创作中心

完善资料让更多小伙伴认识你,还能领取20积分哦,立即完善>

3天内不再提示

如何在TPU上使用JAX训练GPT-2模型

谷歌开发者 ? 来源:谷歌开发者 ? 2025-09-03 11:39 ? 次阅读
加入交流群
微信小助手二维码

扫码添加小助手

加入工程师交流群

作者 / 魏巍,开发技术推广工程师

如果您对如何使用 JAX 从头开始构建语言模型感到好奇,那么本文非常适合您。我们在 2025 年 Google Cloud Next 大会上举办了一场关于此主题的研讨会,并获得了一些很好的反馈,我们也为所有无法参会的开发者编写了这份指南。

本文和代码示例将引导您构建并预训练 GPT-2 模型,了解 JAX 如何直接利用 Google TPU 的强大能力。您可以使用 Colab 或 Kaggle 中的 TPU 免费运行整个项目,并获取完整的Notebook。

Notebook

https://github.com/windmaple/LLM_from_scratch.JAX/tree/main/02.GPT2-pretraining

这是一个实践教程,如果您还不熟悉 JAX,我们建议您从《PyTorch 开发者指南: JAX 基础知识》入手。

PyTorch 开发者指南: JAX 基础知识

https://cloud.google.com/blog/products/ai-machine-learning/guide-to-jax-for-pytorch-developers

首先,让我们快速了解一下将要用到的工具。

JAX 生态系统

在开始构建模型之前,让我们先简要介绍一下 JAX 生态系统。JAX 生态系统采用模块化方法,通过 JAX 核心提供核心数值处理能力,而一系列丰富的库则在此基础上构建而成,以满足不同应用的特定需求,如用于构建神经网络的Flax、用于检查点和模型持久性的Orbax以及用于优化的Optax(在本文中,这 3 个工具都将被用到)。内置函数转换,如 autograd、矢量化和 JIT 编译,加上强大的性能和易于使用的 API,使 JAX 非常适合训练大语言模型。

JAX 生态系统

https://docs.jax.dev/en/latest/#ecosystem

Flax

https://github.com/google/flax

Orbax

https://github.com/google/orbax

Optax

https://github.com/google-deepmind/optax

入门指南: 构建您的 GPT-2 模型

OpenAI 此前发布了GPT-2 模型代码和权重,这为我们提供了宝贵的参考资料,并且社区也付出了很多努力来复现该模型,例如nanoGPT。以下是 GPT-2 的高层级模型架构图:

dedd83ce-84bb-11f0-a18e-92fbcf53809c.png

GPT-2 模型代码和权重

https://github.com/openai/gpt-2

nanoGPT

https://github.com/karpathy/nanoGPT

我们将使用NNX (新的 Flax 接口)来构建 GPT-2 模型。简洁起见,我们重点关注 Transformer Block,这是现代大语言模型的关键所在。Transformer Block 会捕获任何序列的长程依赖关系,并构建丰富的上下文理解。GPT-2 Transformer Block 由 2 个 LayerNorm 层、1 个多头注意力 (MHA) 层、2 个 Dropout 层、2 个线性投影层和 2 个残差连接组成。因此,我们首先需要在TransformerBlock类的__init__函数中定义这些层:

classTransformerBlock(nnx.Module):
 def__init__(
    self,
    embed_dim:int,
    num_heads:int,
    ff_dim:int,
    dropout_rate:float,
    rngs: nnx.Rngs,
  ):
    self.layer_norm1 = nnx.LayerNorm(
      epsilon=1e-6, num_features=embed_dim, rngs=rngs
    )
    self.mha = nnx.MultiHeadAttention(
      num_heads=num_heads, in_features=embed_dim, rngs=rngs
    )
    self.dropout1 = nnx.Dropout(rate=dropout_rate)
    self.layer_norm2 = nnx.LayerNorm(
      epsilon=1e-6, num_features=embed_dim, rngs=rngs
    )
    self.linear1 = nnx.Linear(
      in_features=embed_dim, out_features=ff_dim, rngs=rngs
    )
    self.linear2 = nnx.Linear(
      in_features=ff_dim, out_features=embed_dim, rngs=rngs
    )
    self.dropout2 = nnx.Dropout(rate=dropout_rate)

NNX (新的 Flax 接口)

https://flax.readthedocs.io/en/v0.8.3/experimental/nnx/index.html#

接下来,我们需要在__call__函数中对这些层进行组合:

classTransformerBlock(nnx.Module):
 def__call__(self, inputs, training:bool=False):
    input_shape = inputs.shape
    bs, seq_len, emb_sz = input_shape


    attention_output = self.mha(
      inputs_q=self.layer_norm1(inputs),
      mask=causal_attention_mask(seq_len),
      decode=False,
    )
    x = inputs + self.dropout1(
      attention_output, deterministic=nottraining
    )


   # MLP
    mlp_output = self.linear1(self.layer_norm2(x))
    mlp_output = nnx.gelu(mlp_output)
    mlp_output = self.linear2(mlp_output)
    mlp_output = self.dropout2(
      mlp_output, deterministic=nottraining
    )


   returnx + mlp_output

如果您使用过任何其他机器学习框架 (如 PyTorch 或 TensorFlow) 来训练语言模型,那么您对这段代码应该非常熟悉。但 JAX 具有通过SPMD(Single Program Multiple Data) 自动并行运行代码的强大能力。这项功能至关重要,因为我们将在多个加速器 (多个 TPU 核心) 上运行代码。让我们来看看它的工作原理

SPMD

https://docs.jax.dev/en/latest/sharded-computation.html

要执行 SPMD,首先我们需要确保自己使用的是 TPU。如果您使用的是 Colab 或 Kaggle,请选择 TPU 运行时 (您也可以使用 Cloud TPU 虚拟机)。

import jax
jax.devices()


# Free-tier Colab offers TPU v2:
#[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
# TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
# TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
# TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
# TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
# TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
# TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
# TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

Colab 和 Kaggle 提供 TPU v2 或 v3,其中含有 8 个独立的 TPU 核心。TPU v3 托盘的外观如下所示:

def0e464-84bb-11f0-a18e-92fbcf53809c.png

训练您的 GPT-2 模型

为了高效训练 GPT-2 模型,我们将通过 SPMD 让所有 TPU 核心协同运行,并利用 JAX 中的数据并行。为此,我们定义了一个硬件网格:

mesh= jax.make_mesh((8,1), ('batch','model'))

数据并行

https://en.wikipedia.org/wiki/Data_parallelism

我们可以将网格视为加速器的 2D 矩阵。在本例中,我们为网格定义了两个轴:batch轴和model轴。因此,我们总共有 8 x 1 个核心,也就是 8 个核心。这些轴决定了我们如何划分数据和模型参数。如果之后想尝试其他并行方案,我们可以对这些轴进行调整。

现在,我们通过告诉 JAX 如何使用 "model" 轴划分模型参数来更改__init__函数。这是通过在初始化权重张量 (weight tensors) 时添加nnx.with_partitioning来实现的: 对于像 LayerNorm 缩放/偏置张量这样的 1D 权重张量 (weight tensors),我们直接沿着 "model" 轴对它们进行分片;对于像 MHA 和线性内核张量这样的 2D 权重张量,我们沿着model轴对第二维度进行分片。

classTransformerBlock(nnx.Module):
 def__init__(
    self,
    embed_dim:int,
    num_heads:int,
    ff_dim:int,
    dropout_rate:float,
    rngs: nnx.Rngs,
  ):
    self.layer_norm1 = nnx.LayerNorm(
      epsilon=1e-6, num_features=embed_dim,rngs=rngs, rngs=rngs,
      scale_init=nnx.with_partitioning(
        nnx.initializers.ones_init(),
        ("model"),
      ),
      bias_init=nnx.with_partitioning(
        nnx.initializers.zeros_init(),
       ("model"),
      ),
    )
    self.mha = nnx.MultiHeadAttention(
      num_heads=num_heads, in_features=embed_dim,
      kernel_init=nnx.with_partitioning(
        nnx.initializers.xavier_uniform(),
       (None,"model"),
      ),
      bias_init=nnx.with_partitioning(
        nnx.initializers.zeros_init(),
       ("model"),
      ),
    )
   # Other layers in the block are omitted for brevity

我们需要像这样划分其他层,以便为整个 GPT-2 模型启用模型张量并行。即使我们在本教程中不会使用模型张量并行,实现这一功能仍然是比较好的做法,因为随着模型规模的增长,我们将来可能需要对模型参数进行分区。实现后,我们只需更改一行代码即可立即运行更大的模型。例如:

mesh= jax.make_mesh((4,2), ('batch','model'))

接下来,我们需要定义loss_fn和train_step函数,与此前文章类似。train_step()函数会计算交叉熵损失函数的梯度,并通过优化器更新权重,然后在循环中被调用来训练模型。为了获得最佳性能,我们使用@nnx.jit装饰器对这两个函数进行 JIT 编译,因为它们属于计算密集型函数。

@nnx.jit
defloss_fn(model, batch):
  logits = model(batch[0])
  loss = optax.softmax_cross_entropy_with_integer_labels(
    logits=logits, labels=batch[1]
  ).mean()
 returnloss, logits




@nnx.jit
deftrain_step(
  model: nnx.Module,
  optimizer: nnx.Optimizer,
  metrics: nnx.MultiMetric,
  batch,
):
  grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
  (loss, logits), grads = grad_fn(model, batch)
  metrics.update(loss=loss, logits=logits, lables=batch[1])
  optimizer.update(grads)

此前文章

https://cloud.google.com/blog/products/ai-machine-learning/guide-to-jax-for-pytorch-developers

对于优化器,我们使用 Optax 中的 AdamW 以及余弦衰减调度。您也可以在 Optax 中试用其他优化器或调度计划。

schedule = optax.cosine_decay_schedule(
  init_value=init_learning_rate, decay_steps=max_steps
)
optax_chain = optax.chain(
  optax.adamw(learning_rate=schedule, weight_decay=weight_decay)
)
optimizer = nnx.Optimizer(model, optax_chain)

其他优化器

https://optax.readthedocs.io/en/latest/api/optimizers.html

调度计划

https://optax.readthedocs.io/en/latest/api/optimizer_schedules.html

最后,我们需要创建一个简单的训练循环。

while True:
  input_batch, target_batch =get_batch("train")


 train_step(
    model,
    optimizer,
    train_metrics,
    jax.device_put(
      (input_batch, target_batch),
     NamedSharding(mesh,P("batch", None)),
    ),
  )


  step +=1
  if step > max_steps:
    break

请注意我们使用jax.device_put函数沿着 batch 轴对输入数据进行分区。在这种情况下,JAX 将启用数据并行,并通过自动插入通信集合 (AllReduce) 将所有内容整合在一起,同时尽可能多地实现计算与通信的重叠。有关并行计算更深入的讨论,请参阅 JAX 的并行编程入门文档。

并行编程入门

https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#intro-and-a-quick-example

模型此时应处于训练状态,如果使用权重和偏差来跟踪运行情况,我们便可以观察训练损失。以下是训练 GPT-2 124M 模型的测试运行结果:

df146da8-84bb-11f0-a18e-92fbcf53809c.png

权重和偏差

https://wandb.ai/site

如果使用 Kaggle TPU v3,训练时间大约为 7 个小时 (我们可以不中断地使用 Kaggle TPU v3 9 个小时);但如果使用Trillium,训练时间将缩短至约 1.5 个小时 (请注意,Trillium 的每个芯片配备 32G 高带宽内存 (HBM),因此我们可以将批量大小加倍,并将训练步数减半)。

Trillium

https://cloud.google.com/blog/products/compute/trillium-tpu-is-ga

最终的损失情况与nanoGPT 的损失情况大致相符。我们在编写此代码示例时对 nanoGPT 进行了研究。

df270288-84bb-11f0-a18e-92fbcf53809c.png

nanoGPT 的损失情况

https://github.com/karpathy/nanoGPT/tree/master?tab=readme-ov-file#baselines

如果使用 Cloud TPU,我们还可以通过 "tpu-info" 命令 (Cloud TPU 监控调试包的一部分) 或权重和偏差仪表盘监控 TPU 利用率。我们的 TPU 正在全力运行!

df3f1f4e-84bb-11f0-a18e-92fbcf53809c.png

Cloud TPU 监控调试

https://github.com/AI-Hypercomputer/cloud-tpu-monitoring-debugging

完成模型训练后,我们可以使用Orbax保存模型:

checkpointer = orbax.PyTreeCheckpointer()
train_state = nnx.pure(nnx.state(model))
checkpointer.save(checkpoint_path, train_state)

Orbax

https://github.com/google/orbax

后续步骤: 探索高级 LLM 训练和扩展

这基本上就是我们训练 GPT-2 模型所需了解的全部内容。您可以在完整的Notebook中找到其他详细信息,如数据加载、超参数、指标等。

Notebook

https://github.com/windmaple/LLM_from_scratch.JAX/tree/main/02.GPT2-pretraining

当然,GPT-2 如今还是一个小模型,许多前沿实验室正在训练拥有数千亿参数的模型。但是,现在您已经学习了如何使用 JAX 和 TPU 构建小语言模型,为深入了解如何扩展模型做好了准备。

如何扩展模型

https://jax-ml.github.io/scaling-book/

此外,您既可以使用MaxText来训练预构建的前沿 LLM,也可以通过参考JAX LLM 示例或Stanford Marin 模型来学习如何从头开始构建最新的模型。

MaxText

https://github.com/AI-Hypercomputer/maxtext

JAX LLM 示例

https://github.com/jax-ml/jax-llm-examples/

Stanford Marin 模型

https://developers.googleblog.com/en/stanfords-marin-foundation-model-first-fully-open-model-developed-using-jax/

我们期待看到您使用 JAX 和 TPU 构建的出色模型!

声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉
  • 模型
    +关注

    关注

    1

    文章

    3573

    浏览量

    50899
  • 代码
    +关注

    关注

    30

    文章

    4913

    浏览量

    71428
  • TPU
    TPU
    +关注

    关注

    0

    文章

    156

    浏览量

    21312
  • pytorch
    +关注

    关注

    2

    文章

    811

    浏览量

    14166

原文标题:实战指南|手把手教您在 TPU 上免费使用 JAX 训练 GPT-2 模型

文章出处:【微信号:Google_Developers,微信公众号:谷歌开发者】欢迎添加关注!文章转载请注明出处。

收藏 人收藏
加入交流群
微信小助手二维码

扫码添加小助手

加入工程师交流群

    评论

    相关推荐
    热点推荐

    用PaddleNLP在4060单卡实践大模型训练技术

    手把手教您如何在单张消费级显卡,利用PaddleNLP实践OpenAI的GPT-2模型的预训练GPT
    的头像 发表于 02-19 16:10 ?1335次阅读
    用PaddleNLP在4060单卡<b class='flag-5'>上</b>实践大<b class='flag-5'>模型</b>预<b class='flag-5'>训练</b>技术

    如何利用Google Colab的云TPU加速Keras模型训练

    TPU包含8个TPU核,每个核都作为独立的处理单元运作。如果没有用上全部8个核心,那就没有充分利用TPU。为了充分加速训练,相比在单GPU
    的头像 发表于 11-16 09:10 ?1.1w次阅读

    OpenAI发布了一个“逆天”的AI模型——GPT2整个模型包含15亿个参数

    能有这样出色的表现,不是没有原因的,GPT-2各种特定领域的语言建模任务中都取得了很好的分数。作为一个没有经过任何领域数据专门训练模型,它的表现,比那些专为特定领域数据集(例如维基百科,新闻,书籍)
    的头像 发表于 03-07 14:45 ?9068次阅读

    OpenAI发布一款令人印象深刻的语言模型GPT-2

    今年2月,OpenAI发布了一款令人印象深刻的语言模型GPT-2,它可以写短篇小说、诗歌,甚至轻松辨别《哈利波特》和《指环王》中的角色。最近,一位加拿大工程师用它创建了一个向公众开放的文本生成器,只需提供一个句子,机器便能自动生
    的头像 发表于 05-17 18:48 ?4892次阅读

    布朗大学90后研究生:我们复现了15亿参数GPT-2模型,你也行!

    模型的实现基于Grover模型,并修改其代码库以匹配GPT-2的语言建模训练目标。由于他们的模型是在类似的大型语料库上进行
    的头像 发表于 09-01 07:11 ?3674次阅读

    OpenAI宣布,发布了7.74亿参数GPT-2语言模型

    就在本周,OpenAI宣布,发布了7.74亿参数GPT-2语言模型,15.58亿的完整模型也有望于几个月内发布,并将GPT-2这6个月的进展情况在博客
    的头像 发表于 09-01 09:10 ?3324次阅读

    和AI聊天,自然语言模型 GPT-2可能会推出个人信息

    Stroudsburg……” 自然语言模型 GPT-2就像是收到了某种暗号,立刻“送出”一套 个人信息:姓名、电话号码,还有地址、邮箱和传真 (部分信息已打码)。 这可不是GPT-2瞎编的,而是真实存在的个人信息!这些个人信息
    的头像 发表于 01-02 09:22 ?2798次阅读

    GPT系列的“高仿” 最大可达GPT-3大小 自主训练

    虽然GPT-3没有开源,却已经有人在复刻GPT系列的模型了。 例如,慕尼黑工业大学的Connor Leahy,此前用200个小时、6000RMB,复现了GPT-2。 又例如,基于150
    的头像 发表于 02-13 09:24 ?3148次阅读

    使用NVIDIA TensorRT优化T5和GPT-2

    在这篇文章中,我们向您介绍了如何将拥抱脸 PyTorch T5 和 GPT-2 模型转换为优化的 TensorRT 推理引擎。 TensorRT 推理机用作原始 HuggingFace T5
    的头像 发表于 03-31 17:25 ?4306次阅读
    使用NVIDIA TensorRT优化T5和<b class='flag-5'>GPT-2</b>

    基于OpenAI的GPT-2的语言模型ProtGPT2可生成新的蛋白质序列

    人类语言与蛋白质有很多共同点,至少在计算建模方面。这使得研究团队将自然语言处理(NLP)的新方法应用于蛋白质设计。其中,德国Bayreuth大学Birte H?cker的蛋白质设计实验室,描述了基于OpenAI的GPT-2的语言模型ProtGPT
    的头像 发表于 09-08 16:24 ?2913次阅读

    GPT/GPT-2/GPT-3/InstructGPT进化之路

    在预训练阶段,GPT 选择 transformer 的 decoder 部分作为模型的主要模块,transformer 是 2017年 google 提出的一种特征抽取模型
    的头像 发表于 03-03 11:14 ?4512次阅读

    DeepSpeed里面和Zero相关技术教程

    和NVMe 分配大规模Megatron-LM模型 以内存为中心的分块优化 提取权重 ZeRO-Offload概述 训练环境 在单个 V100 GPU 训练10B的
    的头像 发表于 06-12 10:25 ?4994次阅读
    DeepSpeed里面和Zero相关技术教程

    DeepSpeed结合Megatron-LM训练GPT2模型笔记

    本文基于DeepSpeedExamples仓库中给出的Megatron相关例子探索一下训练GPT2模型的流程。主要包含3个部分,第一个部分是基于原始的Megatron如何训练
    的头像 发表于 06-19 14:45 ?4454次阅读
    DeepSpeed结合Megatron-LM<b class='flag-5'>训练</b><b class='flag-5'>GPT2</b><b class='flag-5'>模型</b>笔记

    TPU-MLIR之量化感知训练

    TPU-MLIR之量化感知训练
    的头像 发表于 08-21 10:47 ?1273次阅读
    <b class='flag-5'>TPU</b>-MLIR之量化感知<b class='flag-5'>训练</b>

    用PaddleNLP为GPT-2模型制作FineWeb二进制预训练数据集

    作者:算力魔方创始人/英特尔创新大使刘力 《用PaddleNLP在4060单卡实践大模型训练技术》发布后收到读者热烈反响,很多读者要求进一步讲解更多的技术细节。本文主要针对大语言模型
    的头像 发表于 03-21 18:24 ?2973次阅读
    用PaddleNLP为<b class='flag-5'>GPT-2</b><b class='flag-5'>模型</b>制作FineWeb二进制预<b class='flag-5'>训练</b>数据集