PyTorch 2.2 大更新:集成 FlashAttention

新的一年,PyTorch 也迎来了重大更新,PyTorch 2.2 集成了 FlashAttention-2 和 AOTInductor 等新特性,计算性能翻倍。

继去年十月份的 PyTorch 大会发布了 2.1 版本之后,全世界各地的 521 位开发者贡献了 3628 个提交,由此形成了最新的 PyTorch 2.2 版本。

PyTorch 2.2 大更新:集成 FlashAttention休闲区蓝鸢梦想 - Www.slyday.coM

新的版本集成了 FlashAttention-2,使得 scaled_dot_product_attention (SDPA)相较于之前的版本有了约 2 倍的性能提升。

PyTorch 2.2 还引入了一个新的 TorchInductor 提前扩展,称为 AOTInductor,旨在为非 python 服务器端编译和部署 PyTorch 程序。

PyTorch 中的 torch.distributed 支持了一个叫作 device_mesh 的新抽象,用于初始化和表示 ProcessGroups。

PyTorch 2.2 大更新:集成 FlashAttention休闲区蓝鸢梦想 - Www.slyday.coM

另外,PyTorch 2.2 提供了一个标准化的、可配置的日志记录机制,——TORCH_LOGS。

PyTorch 2.2 还对 torch.compile 做了许多改进,包括改进了对编译优化器的支持,以及 TorchInductor 融合和布局优化。

PyTorch 2.2 大更新:集成 FlashAttention休闲区蓝鸢梦想 - Www.slyday.coM

最后值得注意的是,PyTorch 将放弃对 macOS x86 的支持,PyTorch 2.2.x 是支持 macOS x64 的最后一个版本。

PyTorch 2.2 新特性

首先请注意,如果从源代码构建 PyTorch 2.2,需要 GCC 9.4 或更高版本,PyTorch 代码库已从 C++ 14 迁移到 C++ 17。

PyTorch 2.2 大更新:集成 FlashAttention休闲区蓝鸢梦想 - Www.slyday.coM

FlashAttention-2

FlashAttention-2 通过优化 GPU 上不同线程块和 warps 之间的工作分区,来解决占用率低或不必要的共享内存读写。

PyTorch 2.2 大更新:集成 FlashAttention休闲区蓝鸢梦想 - Www.slyday.coM

FlashAttention-2 调整了算法以减少非 matmul 的计算量,同时提升了 Attention 计算的并行性(即使是单个头,也可以跨不同的线程块,以增加占用率),在每个线程块中,优化 warps 之间的工作分配,以减少通过共享内存的通信。

PyTorch 2.2 将 FlashAttention 内核更新到了 v2 版本,不过需要注意的是,之前的 Flash Attention 内核具有 Windows 实现,Windows 用户可以强制使用 sdp_kernel,仅启用 Flash Attention 的上下文管理器。

PyTorch 2.2 大更新:集成 FlashAttention休闲区蓝鸢梦想 - Www.slyday.coM

而在 2.2 中,如果必须使用 sdp_kernel 上下文管理器,请使用 memory efficient 或 math 内核(在 Windows 上)。

PyTorch 2.2 大更新:集成 FlashAttention休闲区蓝鸢梦想 - Www.slyday.coM

在 FlashAttention-2 的加持之下,torch.nn.functional.scaled_dot_product_attention 的速度提升了大约 2 倍,在 A100 GPU 上达到了理论计算峰值的 50%-73%。

AOTInductor

AOTInductor 是 TorchInductor 的扩展,用于处理导出的 PyTorch 模型,对其进行优化,并生成共享库以及其他相关工件。

这些编译的工件可以部署在非 Python 环境中,经常用于服务器端的推理。

下面的示例演示了如何调用 aot_compile 将模型转换为共享库。

PyTorch 2.2 大更新:集成 FlashAttention休闲区蓝鸢梦想 - Www.slyday.coM

AOTInductor 支持与 Inductor 相同的后端,包括 CUDA、ROCm 和 CPU。

TORCH_LOGS

PyTorch 2.2 提供了一个标准化的、可配置的日志记录机制,可用于分析各种子系统的状态,例如编译和分布式操作可以通过 TORCH_LOGS 环境变量启用日志。比如通过在命令行中修改环境变量:

PyTorch 2.2 大更新:集成 FlashAttention休闲区蓝鸢梦想 - Www.slyday.coM

将 TorchDynamo 的日志级别设置为 logging.ERROR,将 TorchInductor 的日志级别设置为 logging.DEBUG。

当然也可以在代码中以 API 的形式使用:

PyTorch 2.2 大更新:集成 FlashAttention休闲区蓝鸢梦想 - Www.slyday.coM

torch.distributed.device_mesh

PyTorch 2.2 引入了一个新的抽象,用于表示分布式并行中涉及的 ProcessGroups,称为 torch.distributed.device_mesh。

为分布式训练设置分布式通信器(NCCL)是一件麻烦的事情。用户需要编写不同并行度的工作负载,并为每个并行度手动设置和管理 NCCL 通信器(ProcessGroup )。

这个过程可能很复杂,容易出错。而 DeviceMesh 可以简化此过程,使其更易于管理。

DeviceMesh 是管理 ProcessGroup 的更高级别的抽象。它允许用户毫不费力地创建节点间和节点内进程组,而不必担心如何为不同的子进程组正确设置等级。

例如,数组的其中一个维度可以表示 FSDP 中的数据并行(data parallelism),而另一个维度可以表示 FSDP 中的张量并行(tensor parallelism)。

用户还可以通过 DeviceMesh 轻松管理底层 process_groups,以实现多维并行。

PyTorch 2.2 大更新:集成 FlashAttention休闲区蓝鸢梦想 - Www.slyday.coM

DeviceMesh 在处理多维并行性(如 3D 并行)时很有用。如上图所示,当你的并行解决方案需要跨主机和每个主机内部进行通信时,可以创建一个 2D 网格,用于连接每个主机中的设备,并以同构设置将每个设备与其他主机上的对应设备连接起来。

借助 init_device_mesh (),我们可以在短短两行内完成上面这个 2D 设置:

PyTorch 2.2 大更新:集成 FlashAttention休闲区蓝鸢梦想 - Www.slyday.coM

而如果不使用 DeviceMesh,我们大概需要自己写下面这一堆代码:

PyTorch 2.2 大更新:集成 FlashAttention休闲区蓝鸢梦想 - Www.slyday.coM

当然,如果需要,我们仍然可以访问底层 ProcessGroup:

PyTorch 2.2 大更新:集成 FlashAttention休闲区蓝鸢梦想 - Www.slyday.coM

优化器的改进

大概有以下几点:

编译优化器在所有基准测试中都提高了性能:HuggingFace +18%、TorchBench +19%、TIMM +8% E2E;

编译的优化器增加对 cudagraphs 的支持;

对测试套件中所有模型进行平均,每个测试套件的基准测试平均编译时间增加约 40 秒;正在进行的优化可能会将其降低到 30 秒以下。

用于多张量优化器编译的 inductor 中缺少的主要功能是 foreach 算子的高效编码生成。

在调度器内部,将所有在下放过程中注册的缓冲区列表凝聚到 ForeachKernelSchedulerNodes 中(FusedSchedulerNode 的子类)。

为了检查融合是否合法,每个内部 SchedulerNode 执行的写操作必须与消费 SchedulerNode 在同一列表索引处的读操作相匹配。

PyTorch 2.2 大更新:集成 FlashAttention休闲区蓝鸢梦想 - Www.slyday.coM

此外,正常的垂直融合规则必须允许在消费者和生产者 SchedulerNode 列表的每个索引处进行融合。

如果满足了这些条件,ForeachKernelSchedulerNode 将垂直融合成一个 ForeachKernelSchedulerNode,其中每个列表上的相应点操作都将被融合。

通过实现这种融合,可以将一系列 foreach 运算融合到单个内核中,从而实现多张量优化器的完全融合。

性能改进

TorchInductor 中添加了许多性能优化,包括对 torch.concat 的水平融合支持、改进的卷积布局优化、以及改进 scaled_dot_product_attention 模式匹配。

PyTorch 2.2 大更新:集成 FlashAttention休闲区蓝鸢梦想 - Www.slyday.coM

PyTorch 2.2 还包括 aarch64 的许多性能增强,包括对 mkldnn 权重预打包的支持、改进的 ideep 基元缓存,以及通过对 OneDNN 的固定格式内核改进,来提高推理速度。

参考资料:

https://pytorch.org/blog/pytorch2-2/

广告声明:文内含有的对外跳转链接(包括不限于超链接、二维码、口令等形式),用于传递更多信息,节省甄选时间,结果仅供参考,IT之家所有文章均包含本声明。

相关推荐

  • 友情链接:
  • PHPCMSX
  • 智慧景区
  • 微信扫一扫

    微信扫一扫
    返回顶部

    显示

    忘记密码?

    显示

    显示

    获取验证码

    Close