Поделиться через


Руководство по анализу данных с помощью glm

Сведения о том, как выполнять линейную и логистическую регрессию с помощью обобщенной линейной модели (GLM) в Azure Databricks. glm соответствует обобщенной линейной модели, аналогичной R glm().

Синтаксис: glm(formula, data, family...)

Параметры:

  • formula: символьное описание модели, с которой будет выполняться сопоставление, например ResponseVariable ~ Predictor1 + Predictor2. Поддерживаемые операторы: ~, +, - и ..
  • data: любой кадр данных Spark
  • family: строковое выражение со значением "gaussian" для линейной регрессии или "binomial" для логистической регрессии
  • lambda: числовое значение, которое обозначает параметр регуляризации
  • alpha: числовое значение, которое обозначает параметр смешивания эластичной сети

Выходные данные: модель конвейера MLlib

В этом руководстве вы узнаете, как выполнять линейную и логистическую регрессию в наборе данных Diamonds.

Загрузка данных Diamonds и их разбивка на наборы для обучения и проверки

require(SparkR)

# Read diamonds.csv dataset as SparkDataFrame
diamonds <- read.df("/databricks-datasets/Rdatasets/data-001/csv/ggplot2/diamonds.csv",
                  source = "com.databricks.spark.csv", header="true", inferSchema = "true")
diamonds <- withColumnRenamed(diamonds, "", "rowID")

# Split data into Training set and Test set
trainingData <- sample(diamonds, FALSE, 0.7)
testData <- except(diamonds, trainingData)

# Exclude rowIDs
trainingData <- trainingData[, -1]
testData <- testData[, -1]

print(count(diamonds))
print(count(trainingData))
print(count(testData))
head(trainingData)

Обучение модели линейной регрессии с помощью glm()

В этом разделе показано, как спрогнозировать цену алмазов по их характеристикам, обучив модель линейной регрессии по набору данных для обучения.

Существует сочетание категориальных признаков (вырезание - Идеальный, Премиум, Очень Хороший...) и непрерывные функции (глубина, карат). SparkR автоматически кодирует эти функции, поэтому вам не нужно кодировать эти функции вручную.

# Family = "gaussian" to train a linear regression model
lrModel <- glm(price ~ ., data = trainingData, family = "gaussian")

# Print a summary of the trained model
summary(lrModel)

Используйте predict() для тестовых данных, чтобы оценить качество работы модели с новыми данными.

Синтаксис: predict(model, newData)

Параметры:

  • model: модель MLlib
  • newData: кадр данных Spark, который обычно содержит тестовый набор

Выходные данные: SparkDataFrame

# Generate predictions using the trained model
predictions <- predict(lrModel, newData = testData)

# View predictions against mpg column
display(select(predictions, "price", "prediction"))

Оценка модели.

errors <- select(predictions, predictions$price, predictions$prediction, alias(predictions$price - predictions$prediction, "error"))
display(errors)

# Calculate RMSE
head(select(errors, alias(sqrt(sum(errors$error^2 , na.rm = TRUE) / nrow(errors)), "RMSE")))

Обучение модели логистической регрессии с помощью glm()

В этом разделе показано, как создать логистическую регрессию под одному набору данных, чтобы спрогнозировать огранку алмаза на основе некоторых его характеристик.

Логистическая регрессия в MLlib поддерживает двоичную классификацию. Чтобы протестировать алгоритм в этом примере, подмножество данных для работы с двумя метками.

# Subset data to include rows where diamond cut = "Premium" or diamond cut = "Very Good"
trainingDataSub <- subset(trainingData, trainingData$cut %in% c("Premium", "Very Good"))
testDataSub <- subset(testData, testData$cut %in% c("Premium", "Very Good"))
# Family = "binomial" to train a logistic regression model
logrModel <- glm(cut ~ price + color + clarity + depth, data = trainingDataSub, family = "binomial")

# Print summary of the trained model
summary(logrModel)
# Generate predictions using the trained model
predictionsLogR <- predict(logrModel, newData = testDataSub)

# View predictions against label column
display(select(predictionsLogR, "label", "prediction"))

Оценка модели.

errorsLogR <- select(predictionsLogR, predictionsLogR$label, predictionsLogR$prediction, alias(abs(predictionsLogR$label - predictionsLogR$prediction), "error"))
display(errorsLogR)