使用 Optuna 進行超參數微調
Optuna 是一個開放原始碼的 Python 程式庫,可進行超參數微調,可以跨多個計算資源水平擴展。 Optuna 還會與 MLflow 整合,以進行模型和試用追蹤和監視。
安裝 Optuna
使用下列命令來安裝 Optuna 及其整合模組。
%pip install optuna
%pip install optuna-integration # Integration with MLflow
定義搜尋空間並執行 Optuna 最佳化
以下是 Optuna 工作流程中的步驟:
- 定義要優化的目標函式。 在目標函式中,定義超參數搜尋空間。
- 建立 Optuna Study 物件,並呼叫 Study 物件的
optimize
函式來執行微調演算法。
以下是 Optuna 文件的最小範例。
- 定義目標函式
objective
,並呼叫suggest_float
函式,以定義參數x
的搜尋空間。 - 建立一個實驗,並使用100次試驗優化
objective
函數,也就是通過不同的objective
值進行x
函數的100次呼叫。 - 完成最佳研究參數的取得
def objective(trial):
x = trial.suggest_float("x", -10, 10)
return (x - 2) ** 2
study = optuna.create_study()
study.optimize(objective, n_trials=100)
best_params = study.best_params
將 Optuna 試驗平行處理至多台機器
您可以使用 Joblib Apache Spark 後端,將 Optuna 試用散發至 Azure Databricks 叢集中的多部機器。
import joblib
from joblibspark import register_spark
register_spark() # register Spark backend for Joblib
with joblib.parallel_backend("spark", n_jobs=-1):
study.optimize(objective, n_trials=100)
與 MLflow 整合
若要追蹤所有 Optuna 試用的超參數和計量,請在呼叫 MLflowCallback
函式時,使用 Optuna 整合模組的 optimize
。
import mlflow
from optuna.integration.mlflow import MLflowCallback
mlflow_callback = MLflowCallback(
tracking_uri="databricks",
metric_name="accuracy",
create_experiment=False,
mlflow_kwargs={
"experiment_id": experiment_id
}
)
study.optimize(objective, n_trials=100, callbacks=[mlflow_callback])
筆記本範例
此筆記本提供了一個使用 Optuna 選取 scikit-learn 模型及其 鳶尾花數據集超參數的範例。
除了單一機器的 Optuna 工作流程外,筆記本還展示如何
- 透過 Joblib,將 Optuna 試用平行處理至多部電腦
- 使用 MLflow 追蹤實驗運行