author     Eric Liang <ekl@databricks.com>  2015-07-30 16:15:43 -0700
committer  Xiangrui Meng <meng@databricks.com>  2015-07-30 16:15:43 -0700
commit     e7905a9395c1a002f50bab29e16a729e14d4ed6f (patch)
tree       37758d36fd51f330ca7b4ce2b9f9bb47784a2dcb
parent     be7be6d4c7d978c20e601d1f5f56ecb3479814cb (diff)
[SPARK-9463] [ML] Expose model coefficients with names in SparkR RFormula
Preview:

```
> summary(m)
            features coefficients
1        (Intercept)    1.6765001
2       Sepal_Length    0.3498801
3 Species.versicolor   -0.9833885
4  Species.virginica   -1.0075104
```

Design doc from umbrella task: https://docs.google.com/document/d/10NZNSEurN2EdWM31uFYsgayIPfCFHiuIu3pCWrUmP_c/edit

cc mengxr

Author: Eric Liang <ekl@databricks.com>

Closes #7771 from ericl/summary and squashes the following commits:

ccd54c3 [Eric Liang] second pass
a5ca93b [Eric Liang] comments
2772111 [Eric Liang] clean up
70483ef [Eric Liang] fix test
7c247d4 [Eric Liang] Merge branch 'master' into summary
3c55024 [Eric Liang] working
8c539aa [Eric Liang] first pass
-rw-r--r--  R/pkg/NAMESPACE                                                             |  3
-rw-r--r--  R/pkg/R/mllib.R                                                             | 26
-rw-r--r--  R/pkg/inst/tests/test_mllib.R                                               | 11
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala       | 12
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala            | 12
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala            | 27
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala |  8
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala  |  8
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala       | 18
9 files changed, 108 insertions(+), 17 deletions(-)
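For orientation, here is a minimal Scala sketch of the pipeline this patch operates on, reconstructed from the SparkRWrappers diff below. `df` and the iris column names are assumed inputs, not part of the patch:

```
import org.apache.spark.ml.{Pipeline, PipelineStage}
import org.apache.spark.ml.feature.RFormula
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.sql.DataFrame

// RFormula feeding a LinearRegression inside a Pipeline, mirroring
// fitRModelFormula in the SparkRWrappers diff below. `df` is an assumed
// DataFrame holding the iris columns.
def fitIris(df: DataFrame) = {
  val formula = new RFormula().setFormula("Sepal_Width ~ Sepal_Length + Species")
  val estimator = new LinearRegression()
  new Pipeline().setStages(Array[PipelineStage](formula, estimator)).fit(df)
}
```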
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index 7f7a8a2e4d..a329e14f25 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -12,7 +12,8 @@ export("print.jobj")
# MLlib integration
exportMethods("glm",
- "predict")
+ "predict",
+ "summary")
# Job group lifecycle management methods
export("setJobGroup",
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
index 6a8bacaa55..efddcc1d8d 100644
--- a/R/pkg/R/mllib.R
+++ b/R/pkg/R/mllib.R
@@ -71,3 +71,29 @@ setMethod("predict", signature(object = "PipelineModel"),
function(object, newData) {
return(dataFrame(callJMethod(object@model, "transform", newData@sdf)))
})
+
+#' Get the summary of a model
+#'
+#' Returns the summary of a model produced by glm(), similar to R's summary().
+#'
+#' @param model A fitted MLlib model
+#' @return a list with a 'coefficients' component, which is the matrix of coefficients. See
+#' summary.glm for more information.
+#' @rdname glm
+#' @export
+#' @examples
+#'\dontrun{
+#' model <- glm(y ~ x, trainingData)
+#' summary(model)
+#'}
+setMethod("summary", signature(object = "PipelineModel"),
+ function(object) {
+ features <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
+ "getModelFeatures", object@model)
+ weights <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
+ "getModelWeights", object@model)
+ coefficients <- as.matrix(unlist(weights))
+ colnames(coefficients) <- c("Estimate")
+ rownames(coefficients) <- unlist(features)
+ return(list(coefficients = coefficients))
+ })
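The method above assembles its result purely positionally: the two wrapper calls return parallel arrays with the intercept first. A hedged Scala analogue of that pairing, REPL-style with assumed inputs:

```
// Sketch of the pairing summary() performs: parallel feature/weight arrays
// (intercept first) zipped into an "Estimate" column keyed by feature name.
// The input arrays here are assumed for illustration.
def coefficientTable(features: Array[String], weights: Array[Double]): Seq[String] =
  features.zip(weights).map { case (name, est) => f"$name%-20s $est%12.7f" }.toSeq

coefficientTable(
  Array("(Intercept)", "Sepal_Length"),
  Array(1.6765001, 0.3498801)
).foreach(println)
// (Intercept)             1.6765001
// Sepal_Length            0.3498801
```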
diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/test_mllib.R
index 3bef693247..f272de78ad 100644
--- a/R/pkg/inst/tests/test_mllib.R
+++ b/R/pkg/inst/tests/test_mllib.R
@@ -48,3 +48,14 @@ test_that("dot minus and intercept vs native glm", {
rVals <- predict(glm(Sepal.Width ~ . - Species + 0, data = iris), iris)
expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)
})
+
+test_that("summary coefficients match with native glm", {
+ training <- createDataFrame(sqlContext, iris)
+ stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training))
+ coefs <- as.vector(stats$coefficients)
+ rCoefs <- as.vector(coef(glm(Sepal.Width ~ Sepal.Length + Species, data = iris)))
+ expect_true(all(abs(rCoefs - coefs) < 1e-6))
+  expect_true(all(
+    rownames(stats$coefficients) ==
+    c("(Intercept)", "Sepal_Length", "Species__versicolor", "Species__virginica")))
+})
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
index 3825942795..9c60d4084e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
@@ -66,7 +66,6 @@ class OneHotEncoder(override val uid: String) extends Transformer
def setOutputCol(value: String): this.type = set(outputCol, value)
override def transformSchema(schema: StructType): StructType = {
- val is = "_is_"
val inputColName = $(inputCol)
val outputColName = $(outputCol)
@@ -79,17 +78,17 @@ class OneHotEncoder(override val uid: String) extends Transformer
val outputAttrNames: Option[Array[String]] = inputAttr match {
case nominal: NominalAttribute =>
if (nominal.values.isDefined) {
- nominal.values.map(_.map(v => inputColName + is + v))
+ nominal.values
} else if (nominal.numValues.isDefined) {
- nominal.numValues.map(n => Array.tabulate(n)(i => inputColName + is + i))
+ nominal.numValues.map(n => Array.tabulate(n)(_.toString))
} else {
None
}
case binary: BinaryAttribute =>
if (binary.values.isDefined) {
- binary.values.map(_.map(v => inputColName + is + v))
+ binary.values
} else {
- Some(Array.tabulate(2)(i => inputColName + is + i))
+ Some(Array.tabulate(2)(_.toString))
}
case _: NumericAttribute =>
throw new RuntimeException(
@@ -123,7 +122,6 @@ class OneHotEncoder(override val uid: String) extends Transformer
override def transform(dataset: DataFrame): DataFrame = {
// schema transformation
- val is = "_is_"
val inputColName = $(inputCol)
val outputColName = $(outputCol)
val shouldDropLast = $(dropLast)
@@ -142,7 +140,7 @@ class OneHotEncoder(override val uid: String) extends Transformer
math.max(m0, m1)
}
).toInt + 1
- val outputAttrNames = Array.tabulate(numAttrs)(i => inputColName + is + i)
+ val outputAttrNames = Array.tabulate(numAttrs)(_.toString)
val filtered = if (shouldDropLast) outputAttrNames.dropRight(1) else outputAttrNames
val outputAttrs: Array[Attribute] =
filtered.map(name => BinaryAttribute.defaultAttr.withName(name))
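The net effect of the OneHotEncoder change is that output attribute names are now the bare nominal values rather than `<inputCol>_is_<value>`, leaving any prefixing to downstream stages such as RFormula. A small REPL-style sketch with assumed attribute values:

```
import org.apache.spark.ml.attribute.NominalAttribute

// Before this patch the encoder produced names like "size_is_small"; now the
// nominal values pass through untouched. Values here are assumed.
val nominal = NominalAttribute.defaultAttr
  .withName("size")
  .withValues("small", "medium", "large")
val outputAttrNames = nominal.values // Some(Array("small", "medium", "large"))
```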
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
index 0b428d278d..d1726917e4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -17,6 +17,7 @@
package org.apache.spark.ml.feature
+import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.util.parsing.combinator.RegexParsers
@@ -91,11 +92,20 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R
// TODO(ekl) add support for feature interactions
val encoderStages = ArrayBuffer[PipelineStage]()
val tempColumns = ArrayBuffer[String]()
+ val takenNames = mutable.Set(dataset.columns: _*)
val encodedTerms = resolvedFormula.terms.map { term =>
dataset.schema(term) match {
case column if column.dataType == StringType =>
val indexCol = term + "_idx_" + uid
- val encodedCol = term + "_onehot_" + uid
+ val encodedCol = {
+ var tmp = term
+ while (takenNames.contains(tmp)) {
+ tmp += "_"
+ }
+ tmp
+ }
+ takenNames.add(indexCol)
+ takenNames.add(encodedCol)
encoderStages += new StringIndexer().setInputCol(term).setOutputCol(indexCol)
encoderStages += new OneHotEncoder().setInputCol(indexCol).setOutputCol(encodedCol)
tempColumns += indexCol
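The loop above guarantees the encoded column name never collides with an existing column. A standalone sketch of the same scheme (the helper name is mine, not the patch's):

```
import scala.collection.mutable

// Same collision-avoidance scheme as above, extracted for illustration:
// append "_" until the candidate is unused, then reserve it.
def uniqueName(candidate: String, taken: mutable.Set[String]): String = {
  var tmp = candidate
  while (taken.contains(tmp)) tmp += "_"
  taken.add(tmp)
  tmp
}

val taken = mutable.Set("id", "a", "b")
uniqueName("a", taken) // "a_"  ("a" is already a dataset column)
uniqueName("a", taken) // "a__" (a second request appends another underscore)
```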
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
index 9f70592cca..f5a022c31e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
@@ -17,9 +17,10 @@
package org.apache.spark.ml.api.r
+import org.apache.spark.ml.attribute._
import org.apache.spark.ml.feature.RFormula
-import org.apache.spark.ml.classification.LogisticRegression
-import org.apache.spark.ml.regression.LinearRegression
+import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
+import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel}
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.sql.DataFrame
@@ -44,4 +45,26 @@ private[r] object SparkRWrappers {
val pipeline = new Pipeline().setStages(Array(formula, estimator))
pipeline.fit(df)
}
+
+ def getModelWeights(model: PipelineModel): Array[Double] = {
+ model.stages.last match {
+ case m: LinearRegressionModel =>
+ Array(m.intercept) ++ m.weights.toArray
+ case _: LogisticRegressionModel =>
+ throw new UnsupportedOperationException(
+ "No weights available for LogisticRegressionModel") // SPARK-9492
+ }
+ }
+
+ def getModelFeatures(model: PipelineModel): Array[String] = {
+ model.stages.last match {
+ case m: LinearRegressionModel =>
+ val attrs = AttributeGroup.fromStructField(
+ m.summary.predictions.schema(m.summary.featuresCol))
+ Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get)
+ case _: LogisticRegressionModel =>
+ throw new UnsupportedOperationException(
+ "No features names available for LogisticRegressionModel") // SPARK-9492
+ }
+ }
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index 89718e0f3e..3b85ba001b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -36,6 +36,7 @@ import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions.{col, udf}
+import org.apache.spark.sql.types.StructField
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.StatCounter
@@ -146,9 +147,10 @@ class LinearRegression(override val uid: String)
val model = new LinearRegressionModel(uid, weights, intercept)
val trainingSummary = new LinearRegressionTrainingSummary(
- model.transform(dataset).select($(predictionCol), $(labelCol)),
+ model.transform(dataset),
$(predictionCol),
$(labelCol),
+ $(featuresCol),
Array(0D))
return copyValues(model.setSummary(trainingSummary))
}
@@ -221,9 +223,10 @@ class LinearRegression(override val uid: String)
val model = copyValues(new LinearRegressionModel(uid, weights, intercept))
val trainingSummary = new LinearRegressionTrainingSummary(
- model.transform(dataset).select($(predictionCol), $(labelCol)),
+ model.transform(dataset),
$(predictionCol),
$(labelCol),
+ $(featuresCol),
objectiveHistory)
model.setSummary(trainingSummary)
}
@@ -300,6 +303,7 @@ class LinearRegressionTrainingSummary private[regression] (
predictions: DataFrame,
predictionCol: String,
labelCol: String,
+ val featuresCol: String,
val objectiveHistory: Array[Double])
extends LinearRegressionSummary(predictions, predictionCol, labelCol) {
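Keeping the full transform output (instead of selecting only the prediction and label columns) and recording `featuresCol` is what makes the features vector's ML attributes reachable from the summary. A sketch, assuming `model` is a freshly fit LinearRegressionModel; this mirrors what getModelFeatures in SparkRWrappers relies on:

```
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.regression.LinearRegressionModel

// Recover the feature names from the training summary's prediction schema.
// Assumes a model fit with ML attributes attached to the features column.
def featureNames(model: LinearRegressionModel): Array[String] = {
  val field = model.summary.predictions.schema(model.summary.featuresCol)
  AttributeGroup.fromStructField(field).attributes.get.map(_.name.get)
}
```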
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
index 65846a846b..321eeb8439 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
@@ -86,8 +86,8 @@ class OneHotEncoderSuite extends SparkFunSuite with MLlibTestSparkContext {
val output = encoder.transform(df)
val group = AttributeGroup.fromStructField(output.schema("encoded"))
assert(group.size === 2)
- assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("size_is_small").withIndex(0))
- assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("size_is_medium").withIndex(1))
+ assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("small").withIndex(0))
+ assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("medium").withIndex(1))
}
test("input column without ML attribute") {
@@ -98,7 +98,7 @@ class OneHotEncoderSuite extends SparkFunSuite with MLlibTestSparkContext {
val output = encoder.transform(df)
val group = AttributeGroup.fromStructField(output.schema("encoded"))
assert(group.size === 2)
- assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("index_is_0").withIndex(0))
- assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("index_is_1").withIndex(1))
+ assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("0").withIndex(0))
+ assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("1").withIndex(1))
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
index 8148c553e9..6aed3243af 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.attribute._
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLlibTestSparkContext
@@ -105,4 +106,21 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(result.schema.toString == resultSchema.toString)
assert(result.collect() === expected.collect())
}
+
+ test("attribute generation") {
+ val formula = new RFormula().setFormula("id ~ a + b")
+ val original = sqlContext.createDataFrame(
+ Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5))
+ ).toDF("id", "a", "b")
+ val model = formula.fit(original)
+ val result = model.transform(original)
+ val attrs = AttributeGroup.fromStructField(result.schema("features"))
+ val expectedAttrs = new AttributeGroup(
+ "features",
+ Array(
+ new BinaryAttribute(Some("a__bar"), Some(1)),
+ new BinaryAttribute(Some("a__foo"), Some(2)),
+ new NumericAttribute(Some("b"), Some(3))))
+ assert(attrs === expectedAttrs)
+ }
}
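The expected name `a__bar` in this test composes in two steps (my reading of the diffs above, stated as an assumption): RFormula renames the encoded output column from `a` to `a_` to dodge the existing input column, and the assembled attribute name then joins column name and value with `_`:

```
// Assumed composition of the expected attribute name, per the diffs above.
val encodedCol = "a" + "_"               // collision with input column "a" adds "_"
val attrName = encodedCol + "_" + "bar"  // assembler joins column name and value
assert(attrName == "a__bar")
```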