使用 sparkdl.xgboost
之 XGBoost 模型的分散式訓練
重要
這項功能處於公開預覽狀態。
注意
sparkdl.xgboost
從 Databricks Runtime 12.0 ML 開始已遭取代,並在 Databricks Runtime 13.0 ML 和更新版本中移除。 關於將工作負載移轉至 xgboost.spark
的資訊,請參閱 已取代 sparkdl.xgboost 模組的移轉指南。
Databricks Runtime ML 包含以 Python xgboost
套件、sparkdl.xgboost.XgboostRegressor
和 sparkdl.xgboost.XgboostClassifier
為基礎的 PySpark 估算器。 您可以根據這些估算器建立 ML 管線。 如需詳細資訊,請參閱適用於 PySpark 管線的 XGBoost。
Databricks 極力建議sparkdl.xgboost
使用者使用 Databricks Runtime 11.3 LTS ML 或更新版本。 舊版的 Databricks 執行階段版本會受到舊版 sparkdl.xgboost
的錯誤影響。
注意
-
sparkdl.xgboost
模組自 Databricks Runtime 12.0 ML 起已遭取代。 Databricks 建議您移轉程式碼以改用xgboost.spark
模組。 請參閱移轉指南。 - 不支援下列來自
xgboost
套件的 parameters:gpu_id
、output_margin
validate_features
。 - 不支援 parameters
sample_weight
、eval_set
和sample_weight_eval_set
。 請改用 parametersweightCol
與validationIndicatorCol
。 如需詳細資訊,請參閱適用於 PySpark 管線的 XGBoost。 - 不支援 parameters
base_margin
和base_margin_eval_set
。 請改用參數baseMarginCol
。 如需詳細資訊,請參閱適用於 PySpark 管線的 XGBoost。 - 參數
missing
具有與xgboost
套件不同的語意。 在xgboost
套件中,不論missing
的值為何,SciPy 疏鬆矩陣中的零 values 都會被視為遺漏 values。 針對sparkdl
套件中的 PySpark 估算器,除非 setmissing=0
,否則 Spark 疏鬆向量中的零 values 不會被視為遺漏 values。 如果您有稀疏的訓練數據集(大部分特徵 values 有缺失),Databricks 建議設定missing=0
以降低記憶體使用量並提升效能。
分散式訓練
Databricks Runtime ML 支援使用 num_workers
參數的分散式 XGBoost 訓練。 若要使用分散式訓練,請建立分類器或回歸器,並將 setnum_workers
設定為小於或等於叢集上 Spark 工作任務插槽的總數。 若要使用所有 Spark 任務插槽,setnum_workers=sc.defaultParallelism
。
例如:
classifier = XgboostClassifier(num_workers=sc.defaultParallelism)
regressor = XgboostRegressor(num_workers=sc.defaultParallelism)
分散式訓練的限制
- 您無法使用
mlflow.xgboost.autolog
搭配分散式 XGBoost。 - 您無法使用
baseMarginCol
搭配分散式 XGBoost。 - 您無法在已啟用自動調整功能的叢集上使用分散式 XGBoost。 如需停用自動調整的指示,請參閱啟用自動調整。
GPU 訓練
注意
Databricks Runtime 11.3 LTS ML 包含 XGBoost 1.6.1,其不支援具有計算功能 5.2 及以下的 GPU 叢集。
Databricks Runtime 9.1 LTS ML 和更新版本支援 XGBoost 訓練的 GPU 叢集。 若要使用 GPU 叢集,setuse_gpu
到 True
。
例如:
classifier = XgboostClassifier(num_workers=N, use_gpu=True)
regressor = XgboostRegressor(num_workers=N, use_gpu=True)
疑難排解
在多節點訓練期間,如果您收到 NCCL failure: remote process exited or there was a network error
訊息,通常表示 GPU 之間的網路通訊發生問題。 NCCL (NVIDIA 集體通訊程式庫) 無法使用特定網路介面進行 GPU 通訊時,就會發生此問題。
若要解決此問題,請為 set 叢集的 sparkConf 設置 spark.executorEnv.NCCL_SOCKET_IFNAME
到 eth
。 這基本上會將節點中所有工作者的環境變數 NCCL_SOCKET_IFNAME
設定為 eth
。
範例筆記本
此 Notebook 顯示搭配 Spark MLlib 使用 Python 套件 sparkdl.xgboost
。
sparkdl.xgboost
套件自 Databricks Runtime 12.0 ML 起已遭取代。