Keras 3.0 正式发布:大更新整合 PyTorch、JAX,全球 250 万开发者在用

今天,备受广大开发者欢迎的深度学习框架 Keras,正式更新了 3.0 版本,实现了对 PyTorch 和 JAX 的支持,同时性能提升,还能轻松实现大规模分布式训练。

经过 5 个月的公开 Beta 测试,深度学习框架 Keras 3.0 终于面向所有开发者推出。

全新的 Keras 3 对 Keras 代码库进行了完全重写,可以在 JAX、TensorFlow 和 PyTorch 上运行,能够解锁全新大模型训练和部署的新功能。

「Keras 之父」François Chollet 在最新版本发布之前,也是做了多次预告。目前,有 250 + 万的开发者都在使用 Keras 框架。

Keras 3.0 正式发布:大更新整合 PyTorch、JAX,全球 250 万开发者在用休闲区蓝鸢梦想 - Www.slyday.coM

重磅消息:我们刚刚发布了 Keras 3.0!

在 JAX、TensorFlow 和 PyTorch 上运行 Keras

使用 XLA 编译更快地训练

通过新的 Keras 分发 API 解锁任意数量的设备和主机的训练运行

它现在在 PyPI 上上线

Keras 3.0 正式发布:大更新整合 PyTorch、JAX,全球 250 万开发者在用休闲区蓝鸢梦想 - Www.slyday.coM

开发者甚至可以将 Keras 用作低级跨框架语言,以开发自定义组件,例如层、模型或指标。

只需一个代码库,这些组件便可用在 JAX、TensorFlow、PyTorch 中的原生工作流。

Keras 3.0 正式发布:大更新整合 PyTorch、JAX,全球 250 万开发者在用休闲区蓝鸢梦想 - Www.slyday.coM

再次让 Keras 成为多后端

最初的 Keras 可以在 Theano、TensorFlow、CNTK,甚至 MXNet 上运行。

2018 年,由于 Theano 和 CNTK 已停止开发,TensorFlow 似乎成为了唯一可行的选择,于是,Keras 将开发重点放在了 TensorFlow 上。

而到了今年,情况发生了变化。

根据 2023 年 StackOverflow 开发者调查,和 2022 年 Kaggle 机器学习和数据科学调查等显示,

TensorFlow 拥有 55% 到 60% 的市场份额,是 ML 在生产领域的首选。

而 PyTorch 拥有 40% 到 45% 的市场份额,是 ML 在研究领域的首选。

Keras 3.0 正式发布:大更新整合 PyTorch、JAX,全球 250 万开发者在用休闲区蓝鸢梦想 - Www.slyday.coM

与此同时,JAX 虽然市场份额要小得多,但已被 Google DeepMind、Midjourney、Cohere 等生成式 AI 领域的顶级参与者所接受。

于是,开发团队对 Keras 代码库进行了完全重写,新诞生的 Keras 3.0 基于模块化后端架构进行了重构,有能力在任意框架上运行。

同时新的 Keras 也保证了兼容性,比如在使用 TensorFlow 后端时,你可以简单地使用 import keras_core as keras 来替换 from tensorflow import keras

—— 现有的代码将毫无问题地运行,而且由于 XLA 编译,通常性能略有提高。

Keras vs. TensorFlow

小编在这里给大家举一个例子,说明如何从 TensorFlow 的代码转换成 Keras 的形式。

TensorFlow Core Implementation

Keras 3.0 正式发布:大更新整合 PyTorch、JAX,全球 250 万开发者在用休闲区蓝鸢梦想 - Www.slyday.coM

Keras implementation

Keras 3.0 正式发布:大更新整合 PyTorch、JAX,全球 250 万开发者在用休闲区蓝鸢梦想 - Www.slyday.coM

相比之下,我们可以清楚地看到 Keras 带来的简洁性。

TensorFlow 可以对每个变量进行更精细的控制,而 Keras 提供了易用性和快速原型设计的能力。

对于一些开发者来说,Keras 省去了开发中的一些麻烦,降低了编程复杂性,节省了时间成本。

Keras 3.0 新特性

Keras 最大的优势在于,通过出色的 UX、API 设计和可调试性可实现高速开发。

而且,它还是一个经过实战考验的框架,并为世界上一些最复杂、最大规模的 ML 系统提供支持,比如 Waymo 自动驾驶车、YouTube 推荐引擎。

那么,使用新的多后端 Keras 3 还有哪些额外的优势呢?

- 始终为模型获得最佳性能。

在基准测试中,发现 JAX 通常在 GPU、TPU 和 CPU 上提供最佳的训练和推理性能,但结果因模型而异,因为非 XLA TensorFlow 在 GPU 上偶尔会更快。

它能够动态选择为模型提供最佳性能的后端,而无需对代码进行任何更改,这意味着开发者可以以最高效率进行训练和服务。

- 为模型解锁生态系统可选性。

任何 Keras 3 模型都可以作为 PyTorch 模块实例化,可以作为 TensorFlow SavedModel 导出,也可以作为无状态 JAX 函数实例化。

这意味着开发者可以将 Keras 3 模型与 PyTorch 生态系统包,全系列 TensorFlow 部署和生产工具(如 TF-Serving,TF.js 和 TFLite)以及 JAX 大规模 TPU 训练基础架构一起使用。使用 Keras 3 API 编写一个 model.py ,即可访问 ML 世界提供的一切。

- 利用 JAX 的大规模模型并行性和数据并行性。

Keras 3 包含一个全新的分布式 API,即 keras.distribution 命名空间,目前已在 JAX 后端实现(即将在 TensorFlow 和 PyTorch 后端实现)。

通过它,可以在任意模型尺度和聚类尺度上轻松实现模型并行、数据并行以及两者的组合。由于它能将模型定义、训练逻辑和分片配置相互分离,因此使分发工作流易于开发和维护。

- 最大限度地扩大开源模型版本的覆盖面。

想要发布预训练模型?想让尽可能多的人能够使用它吗?如果你在纯 TensorFlow 或 PyTorch 中实现它,它将被大约一半的社区使用。

如果你在 Keras 3 中实现了它,那么任何人都可以立即使用它,无论他们选择的框架是什么(即使他们自己不是 Keras 用户)。在不增加开发成本的情况下实现 2 倍的影响。

- 使用来自任何来源的数据管道。

Keras 3 / fit () / evaluate () predict () 例程与 tf.data.Dataset 对象、PyTorch DataLoader 对象、NumPy 数组、Pandas 数据帧兼容 —— 无论你使用什么后端。你可以在 PyTorch DataLoader 上训练 Keras 3 + TensorFlow 模型,也可以在 tf.data.Dataset 上训练 Keras 3 + PyTorch 模型。

预训练模型

现在,开发者即可开始使用 Keras 3 的各种预训练模型。

所有 40 个 Keras 应用程序模型( keras.applications 命名空间)在所有后端都可用。KerasCV 和 KerasNLP 中的大量预训练模型也适用于所有后端。

其中包括:

- BERT

- OPT

- Whisper

- T5

- Stable Diffusion

- YOLOv8

跨框架开发

Keras 3 能够让开发者创建在任何框架中都相同的组件(如任意自定义层或预训练模型),它允许访问适用于所有后端的 keras.ops 命名空间。

Keras 3 包含 NumPy API 的完整实现,—— 不是「类似 NumPy」,而是真正意义上的 NumPy API,具有相同的函数和参数。比如 ops.matmul、ops.sum、ops.stack、ops.einsum 等函数。

Keras 3.0 正式发布:大更新整合 PyTorch、JAX,全球 250 万开发者在用休闲区蓝鸢梦想 - Www.slyday.coM

Keras 3 还包含 NumPy 中没有的,一组特定于神经网络的函数,例如 ops.softmax, ops.binary_crossentropy, ops.conv 等。

另外,只要开发者使用的运算,全部来自于 keras.ops ,那么自定义的层、损失函数、优化器就可以跨越 JAX、PyTorch 和 TensorFlow,使用相同的代码。

开发者只需要维护一个组件实现,就可以在所有框架中使用它。

Keras 3.0 正式发布:大更新整合 PyTorch、JAX,全球 250 万开发者在用休闲区蓝鸢梦想 - Www.slyday.coM

Keras 架构

下面,我们来稍稍理解一下 Keras 的机制和架构。

在 Keras 中,Sequential 和 Model 类是模型构建的核心,为组装层和定义计算图提供了一个框架。

Sequential 是层的线性堆栈。它是 Model 的子类,专为简单情况而设计,模型由具有一个输入和一个输出的线性层堆栈组成。

Keras 3.0 正式发布:大更新整合 PyTorch、JAX,全球 250 万开发者在用休闲区蓝鸢梦想 - Www.slyday.coM

Sequential 类有以下一些主要特点:

简单性:只需按照要执行的顺序列出图层即可。

自动前向传递:当向 Sequential 模型添加层时,Keras 会自动将每一层的输出连接到下一层的输入,从而创建前向传递,而无需手动干预。

内部状态管理:Sequential 管理层的状态(如权重和偏置)和计算图。调用 compile 时,它会通过指定优化器、损失函数和指标来配置学习过程。

训练和推理:Sequential 类提供了 fit、evaluate 和 predict 等方法,分别用于训练、评估和预测模型。这些方法在内部处理训练循环和推理过程。

Model 类与函数式 API 一起使用,提供了比 Sequential 更大的灵活性。它专为更复杂的架构而设计,包括具有多个输入或输出、共享层和非线性拓扑的模型。

Keras 3.0 正式发布:大更新整合 PyTorch、JAX,全球 250 万开发者在用休闲区蓝鸢梦想 - Www.slyday.coM

Model 类的主要特点有:

层图:Model 允许创建层图,允许一个层连接到多个层,而不仅仅是上一个层和下一个层。

显式输入和输出管理:在函数式 API 中,可以显式定义模型的输入和输出。相比于 Sequential,可以允许更复杂的架构。

连接灵活性:Model 类可以处理具有分支、多个输入和输出以及共享层的模型,使其适用于简单前馈网络以外的广泛应用。

状态和训练管理:Model 类管理所有层的状态和训练过程,同时提供了对层的连接方式,以及数据在模型中的流动方式的更多控制。

Model 类和 Sequential 类都依赖于以下机制:

层注册:在这些模型中添加层时,层会在内部注册,其参数也会添加到模型的参数列表中。

自动微分:在训练过程中,Keras 使用后端引擎(TensorFlow 等)提供的自动微分来计算梯度。这一过程对用户而言是透明的。

后端执行:实际计算(如矩阵乘法、激活等)由后端引擎处理,后端引擎执行模型定义的计算图。

序列化和反序列化:这些类包括保存和加载模型的方法,其中涉及模型结构和权重的序列化。

从本质上讲,Keras 中的 Model 和 Sequential 类抽象掉了定义和管理计算图所涉及的大部分复杂性,使用户能够专注于神经网络的架构,而不是底层的计算机制。

Keras 自动处理各层如何相互连接、数据如何在网络中流动以及如何进行训练和推理操作等错综复杂的细节。

Keras 3.0 正式发布:大更新整合 PyTorch、JAX,全球 250 万开发者在用休闲区蓝鸢梦想 - Www.slyday.coM

对于 Keras 的大更新,有网友使用下面的图片表达自己的看法:

Keras 3.0 正式发布:大更新整合 PyTorch、JAX,全球 250 万开发者在用休闲区蓝鸢梦想 - Www.slyday.coM

虽然小编也不知道为什么要炸 TensorFlow。

还有网友表示刚好可以用上:

Keras 3.0 正式发布:大更新整合 PyTorch、JAX,全球 250 万开发者在用休闲区蓝鸢梦想 - Www.slyday.coM

另一位网友发来贺电,「在 PyTorch 之上使用 Keras 是一项了不起的成就!」

Keras 3.0 正式发布:大更新整合 PyTorch、JAX,全球 250 万开发者在用休闲区蓝鸢梦想 - Www.slyday.coM

当然也有网友唱反调,「我想知道为什么有人会使用 Keras + Torch 而不是普通的 Torch,因为 Torch 与 Tensorflow 不同,它有一组很好的 API」。

Keras 3.0 正式发布:大更新整合 PyTorch、JAX,全球 250 万开发者在用休闲区蓝鸢梦想 - Www.slyday.coM

此时 Tensorflow 的内心:啊对对对,你们说得都对。

参考资料:

https://twitter.com/fchollet/status/1729512791894012011

https://keras.io/keras_3/

广告声明:文内含有的对外跳转链接(包括不限于超链接、二维码、口令等形式),用于传递更多信息,节省甄选时间,结果仅供参考,IT之家所有文章均包含本声明。

相关推荐

  • 友情链接:
  • PHPCMSX
  • 智慧景区
  • 微信扫一扫

    微信扫一扫
    返回顶部

    显示

    忘记密码?

    显示

    显示

    获取验证码

    Close