Microsoft Fabric에서 scikit-learn을 통해 모델 학습
이 문서에서는 scikit-learn 모델의 반복을 학습하고 추적하는 방법을 설명합니다. Scikit-learn 은 지도 및 비지도 학습에 자주 사용되는 인기 있는 오픈 소스 기계 학습 프레임워크입니다. 이 프레임워크는 모델 맞춤, 데이터 전처리, 모델 선택, 모델 평가 등을 위한 도구를 제공합니다.
필수 조건
Notebook 내에 scikit-learn을 설치해야 합니다. 다음 명령을 사용하여 사용자 환경에 scikit-learn 버전을 설치하거나 업그레이드할 수 있습니다.
pip install scikit-learn
기계 학습 실험 설정
MLFLow API를 사용하여 기계 학습 실험을 만들 수 있습니다. MLflow set_experiment()
함수는 아직 기계 학습 실험이 없을 경우 sample-sklearn이라는 이름의 새 기계 학습 실험을 만듭니다.
Notebook에서 다음 명령을 실행하여 실험을 만듭니다.
import mlflow
mlflow.set_experiment("sample-sklearn")
scikit-learn 모델 학습
실험을 설정한 후 샘플 데이터 세트와 로지스틱 회귀 분석 모델을 만듭니다. 다음 코드는 MLflow 실행을 시작하고 메트릭, 매개 변수 및 최종 로지스틱 회귀 분석 모델을 추적합니다. 최종 모델을 생성한 후에는 추가 추적을 위해 결과 모델을 저장할 수 있습니다.
Notebook에서 다음 코드를 실행하고 샘플 데이터 세트 및 로지스틱 회귀 분석 모델을 만듭니다.
import mlflow.sklearn
import numpy as np
from sklearn.linear_model import LogisticRegression
from mlflow.models.signature import infer_signature
with mlflow.start_run() as run:
lr = LogisticRegression()
X = np.array([-2, -1, 0, 1, 2, 1]).reshape(-1, 1)
y = np.array([0, 0, 1, 1, 1, 0])
lr.fit(X, y)
score = lr.score(X, y)
signature = infer_signature(X, y)
print("log_metric.")
mlflow.log_metric("score", score)
print("log_params.")
mlflow.log_param("alpha", "alpha")
print("log_model.")
mlflow.sklearn.log_model(lr, "sklearn-model", signature=signature)
print("Model saved in run_id=%s" % run.info.run_id)
print("register_model.")
mlflow.register_model(
"runs:/{}/sklearn-model".format(run.info.run_id), "sample-sklearn"
)
print("All done")
샘플 데이터 세트에 모델 로드 및 평가
모델을 저장한 후 추론을 위해 로드할 수 있습니다.
Notebook에서 다음 코드를 실행하고 모델을 로드한 다음 샘플 데이터 세트에 대한 유추를 실행합니다.
# Inference with loading the logged model
from synapse.ml.predict import MLflowTransformer
spark.conf.set("spark.synapse.ml.predict.enabled", "true")
model = MLflowTransformer(
inputCols=["x"],
outputCol="prediction",
modelName="sample-sklearn",
modelVersion=1,
)
test_spark = spark.createDataFrame(
data=np.array([-2, -1, 0, 1, 2, 1]).reshape(-1, 1).tolist(), schema=["x"]
)
batch_predictions = model.transform(test_spark)
batch_predictions.show()