使用 Mosaic 流式处理加载数据
本文介绍如何使用 Mosaic 流式处理将数据从 Apache Spark 转换为与 PyTorch 兼容的格式。
Mosaic 流式处理是一个开源数据加载库。 它可直接从已加载为 Apache Spark 数据帧的数据集对深度学习模型进行单节点或分布式训练和评估。 Mosaic 流式处理主要支持 Mosaic Composer,但也与本机 PyTorch、PyTorch Lightning 和 TorchDistributor 集成。 与传统 PyTorch DataLoaders 相比,Mosaic 流式处理提供了一系列优势,包括:
- 与任何数据类型(包括图像、文本、视频和多模式数据)的兼容性。
- 支持主要云存储提供程序(AWS、OCI、GCS、Azure、Databricks UC 卷和任何与 S3 兼容的对象存储,例如 Cloudflare R2、Coreweave、Backblaze b2 等)
- 最大限度地保证正确性,以及最大程度地提升性能、灵活性和易用性。 有关详细信息,请查看相应的主要功能页。
有关 Mosaic 流式处理的一般信息,请查看流式处理 API 文档。
注意
Mosaic 流式处理已预安装到 Databricks Runtime 15.2 ML 及更高版本。
使用 Mosaic 流式处理从 Spark 数据帧加载数据
Mosaic 流式处理提供了一个简单的工作流,用于从 Apache Spark 转换为 Mosaic 数据分片 (MDS) 格式,然后可以加载该格式在分布式环境中使用。
建议的工作流为:
- 使用 Apache Spark 来加载数据,还可以选择对数据进行预处理。
- 使用
streaming.base.converters.dataframe_to_mds
将数据帧保存到磁盘进行暂时存储和/或保存到 Unity Catalog 卷进行持久存储。 此数据将以 MDS 格式存储,并且可以通过对压缩和哈希的支持进行进一步优化。 高级用例还可以包括使用 UDF 对数据进行预处理。 有关详细信息,请查看将 Spark 数据帧转换为 MDS 的教程。 - 使用
streaming.StreamingDataset
将必要的数据加载到内存中。StreamingDataset
是 PyTorch 的 IterableDataset 的一个版本,它具有可弹性确定的随机处理,可实现快速的中时期恢复。 有关详细信息,请查看 StreamingDataset 文档。 - 使用
streaming.StreamingDataLoader
加载训练/评估/测试所需的数据。StreamingDataLoader
是 PyTorch 的 DataLoader 的一个版本,它提供额外的检查点/恢复接口,用于跟踪此设置级别中模型看到的示例数。
有关端到端示例,请参阅以下笔记本:
使用 Mosaic 流式处理笔记本简化从 Spark 到 PyTorch 的数据加载
故障排除:身份验证错误
如果使用 StreamingDataset
从 Unity 目录卷加载数据时看到以下错误,请设置环境变量,如下所示。
ValueError: default auth: cannot configure default credentials, please check https://docs.databricks.com/en/dev-tools/auth.html#databricks-client-unified-authentication to configure credentials for your preferred authentication method.
注意
如果在使用 TorchDistributor
运行分布式训练时看到此错误,则还必须在工作器节点上设置环境变量。
db_host = "https://your-databricks-host.databricks.com"
db_token = "YOUR API TOKEN" # Create a token with either method from https://docs.databricks.com/en/dev-tools/auth/index.html#databricks-authentication-methods
def your_training_function():
import os
os.environ['DATABRICKS_HOST'] = db_host
os.environ['DATABRICKS_TOKEN'] = db_token
# The above function can be distributed with TorchDistributor:
# from pyspark.ml.torch.distributor import TorchDistributor
# distributor = TorchDistributor(...)
# distributor.run(your_training_function)