aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r-- R/pkg/inst/tests/test_mllib.R | 6
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala | 133
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala | 1
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala | 64
4 files changed, 142 insertions, 62 deletions
diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/test_mllib.R
index a492763344..29152a1168 100644
--- a/R/pkg/inst/tests/test_mllib.R
+++ b/R/pkg/inst/tests/test_mllib.R
@@ -35,8 +35,8 @@ test_that("glm and predict", {
test_that("predictions match with native glm", {
training <- createDataFrame(sqlContext, iris)
- model <- glm(Sepal_Width ~ Sepal_Length, data = training)
+ model <- glm(Sepal_Width ~ Sepal_Length + Species, data = training)
vals <- collect(select(predict(model, training), "prediction"))
- rVals <- predict(glm(Sepal.Width ~ Sepal.Length, data = iris), iris)
- expect_true(all(abs(rVals - vals) < 1e-9), rVals - vals)
+ rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris)
+ expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)
})
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
index f7b46efa10..0a95b1ee8d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -17,26 +17,42 @@
package org.apache.spark.ml.feature
+import scala.collection.mutable.ArrayBuffer
import scala.util.parsing.combinator.RegexParsers
import org.apache.spark.annotation.Experimental
-import org.apache.spark.ml.Transformer
+import org.apache.spark.ml.{Estimator, Model, Transformer, Pipeline, PipelineModel, PipelineStage}
import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol}
import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.mllib.linalg.VectorUDT
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
/**
+ * Base trait for [[RFormula]] and [[RFormulaModel]].
+ */
+private[feature] trait RFormulaBase extends HasFeaturesCol with HasLabelCol {
+ /** @group setParam */
+ def setFeaturesCol(value: String): this.type = set(featuresCol, value)
+
+ /** @group setParam */
+ def setLabelCol(value: String): this.type = set(labelCol, value)
+
+ protected def hasLabelCol(schema: StructType): Boolean = {
+ schema.map(_.name).contains($(labelCol))
+ }
+}
+
+/**
* :: Experimental ::
* Implements the transforms required for fitting a dataset against an R model formula. Currently
* we support a limited subset of the R operators, including '~' and '+'. Also see the R formula
* docs here: http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html
*/
@Experimental
-class RFormula(override val uid: String)
- extends Transformer with HasFeaturesCol with HasLabelCol {
+class RFormula(override val uid: String) extends Estimator[RFormulaModel] with RFormulaBase {
def this() = this(Identifiable.randomUID("rFormula"))
@@ -62,19 +78,74 @@ class RFormula(override val uid: String)
/** @group getParam */
def getFormula: String = $(formula)
- /** @group getParam */
- def setFeaturesCol(value: String): this.type = set(featuresCol, value)
+ override def fit(dataset: DataFrame): RFormulaModel = {
+ require(parsedFormula.isDefined, "Must call setFormula() first.")
+ // StringType terms and terms representing interactions need to be encoded before assembly.
+ // TODO(ekl) add support for feature interactions
+ var encoderStages = ArrayBuffer[PipelineStage]()
+ var tempColumns = ArrayBuffer[String]()
+ val encodedTerms = parsedFormula.get.terms.map { term =>
+ dataset.schema(term) match {
+ case column if column.dataType == StringType =>
+ val indexCol = term + "_idx_" + uid
+ val encodedCol = term + "_onehot_" + uid
+ encoderStages += new StringIndexer().setInputCol(term).setOutputCol(indexCol)
+ encoderStages += new OneHotEncoder().setInputCol(indexCol).setOutputCol(encodedCol)
+ tempColumns += indexCol
+ tempColumns += encodedCol
+ encodedCol
+ case _ =>
+ term
+ }
+ }
+ encoderStages += new VectorAssembler(uid)
+ .setInputCols(encodedTerms.toArray)
+ .setOutputCol($(featuresCol))
+ encoderStages += new ColumnPruner(tempColumns.toSet)
+ val pipelineModel = new Pipeline(uid).setStages(encoderStages.toArray).fit(dataset)
+ copyValues(new RFormulaModel(uid, parsedFormula.get, pipelineModel).setParent(this))
+ }
- /** @group getParam */
- def setLabelCol(value: String): this.type = set(labelCol, value)
+ // optimistic schema; does not contain any ML attributes
+ override def transformSchema(schema: StructType): StructType = {
+ if (hasLabelCol(schema)) {
+ StructType(schema.fields :+ StructField($(featuresCol), new VectorUDT, true))
+ } else {
+ StructType(schema.fields :+ StructField($(featuresCol), new VectorUDT, true) :+
+ StructField($(labelCol), DoubleType, true))
+ }
+ }
+
+ override def copy(extra: ParamMap): RFormula = defaultCopy(extra)
+
+ override def toString: String = s"RFormula(${get(formula)})"
+}
+
+/**
+ * :: Experimental ::
+ * A fitted RFormula. Fitting is required to determine the factor levels of formula terms.
+ * @param parsedFormula a pre-parsed R formula.
+ * @param pipelineModel the fitted feature model, including factor to index mappings.
+ */
+@Experimental
+class RFormulaModel private[feature](
+ override val uid: String,
+ parsedFormula: ParsedRFormula,
+ pipelineModel: PipelineModel)
+ extends Model[RFormulaModel] with RFormulaBase {
+
+ override def transform(dataset: DataFrame): DataFrame = {
+ checkCanTransform(dataset.schema)
+ transformLabel(pipelineModel.transform(dataset))
+ }
override def transformSchema(schema: StructType): StructType = {
checkCanTransform(schema)
- val withFeatures = transformFeatures.transformSchema(schema)
+ val withFeatures = pipelineModel.transformSchema(schema)
if (hasLabelCol(schema)) {
withFeatures
- } else if (schema.exists(_.name == parsedFormula.get.label)) {
- val nullable = schema(parsedFormula.get.label).dataType match {
+ } else if (schema.exists(_.name == parsedFormula.label)) {
+ val nullable = schema(parsedFormula.label).dataType match {
case _: NumericType | BooleanType => false
case _ => true
}
@@ -86,24 +157,19 @@ class RFormula(override val uid: String)
}
}
- override def transform(dataset: DataFrame): DataFrame = {
- checkCanTransform(dataset.schema)
- transformLabel(transformFeatures.transform(dataset))
- }
-
- override def copy(extra: ParamMap): RFormula = defaultCopy(extra)
+ override def copy(extra: ParamMap): RFormulaModel = copyValues(
+ new RFormulaModel(uid, parsedFormula, pipelineModel))
- override def toString: String = s"RFormula(${get(formula)})"
+ override def toString: String = s"RFormulaModel(${parsedFormula})"
private def transformLabel(dataset: DataFrame): DataFrame = {
- val labelName = parsedFormula.get.label
+ val labelName = parsedFormula.label
if (hasLabelCol(dataset.schema)) {
dataset
} else if (dataset.schema.exists(_.name == labelName)) {
dataset.schema(labelName).dataType match {
case _: NumericType | BooleanType =>
dataset.withColumn($(labelCol), dataset(labelName).cast(DoubleType))
- // TODO(ekl) add support for string-type labels
case other =>
throw new IllegalArgumentException("Unsupported type for label: " + other)
}
@@ -114,25 +180,32 @@ class RFormula(override val uid: String)
}
}
- private def transformFeatures: Transformer = {
- // TODO(ekl) add support for non-numeric features and feature interactions
- new VectorAssembler(uid)
- .setInputCols(parsedFormula.get.terms.toArray)
- .setOutputCol($(featuresCol))
- }
-
private def checkCanTransform(schema: StructType) {
- require(parsedFormula.isDefined, "Must call setFormula() first.")
val columnNames = schema.map(_.name)
require(!columnNames.contains($(featuresCol)), "Features column already exists.")
require(
!columnNames.contains($(labelCol)) || schema($(labelCol)).dataType == DoubleType,
"Label column already exists and is not of type DoubleType.")
}
+}
- private def hasLabelCol(schema: StructType): Boolean = {
- schema.map(_.name).contains($(labelCol))
+/**
+ * Utility transformer for removing temporary columns from a DataFrame.
+ * TODO(ekl) make this a public transformer
+ */
+private class ColumnPruner(columnsToPrune: Set[String]) extends Transformer {
+ override val uid = Identifiable.randomUID("columnPruner")
+
+ override def transform(dataset: DataFrame): DataFrame = {
+ val columnsToKeep = dataset.columns.filter(!columnsToPrune.contains(_))
+ dataset.select(columnsToKeep.map(dataset.col) : _*)
}
+
+ override def transformSchema(schema: StructType): StructType = {
+ StructType(schema.fields.filter(col => !columnsToPrune.contains(col.name)))
+ }
+
+ override def copy(extra: ParamMap): ColumnPruner = defaultCopy(extra)
}
/**
@@ -149,7 +222,7 @@ private[ml] object RFormulaParser extends RegexParsers {
def expr: Parser[List[String]] = term ~ rep("+" ~> term) ^^ { case a ~ list => a :: list }
def formula: Parser[ParsedRFormula] =
- (term ~ "~" ~ expr) ^^ { case r ~ "~" ~ t => ParsedRFormula(r, t) }
+ (term ~ "~" ~ expr) ^^ { case r ~ "~" ~ t => ParsedRFormula(r, t.distinct) }
def parse(value: String): ParsedRFormula = parseAll(formula, value) match {
case Success(result, _) => result
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala
index c8d065f37a..c4b45aee06 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala
@@ -28,6 +28,7 @@ class RFormulaParserSuite extends SparkFunSuite {
test("parse simple formulas") {
checkParse("y ~ x", "y", Seq("x"))
+ checkParse("y ~ x + x", "y", Seq("x"))
checkParse("y ~ ._foo ", "y", Seq("._foo"))
checkParse("resp ~ A_VAR + B + c123", "resp", Seq("A_VAR", "B", "c123"))
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
index 79c4ccf02d..8148c553e9 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
@@ -31,72 +31,78 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext {
val formula = new RFormula().setFormula("id ~ v1 + v2")
val original = sqlContext.createDataFrame(
Seq((0, 1.0, 3.0), (2, 2.0, 5.0))).toDF("id", "v1", "v2")
- val result = formula.transform(original)
- val resultSchema = formula.transformSchema(original.schema)
+ val model = formula.fit(original)
+ val result = model.transform(original)
+ val resultSchema = model.transformSchema(original.schema)
val expected = sqlContext.createDataFrame(
Seq(
- (0, 1.0, 3.0, Vectors.dense(Array(1.0, 3.0)), 0.0),
- (2, 2.0, 5.0, Vectors.dense(Array(2.0, 5.0)), 2.0))
+ (0, 1.0, 3.0, Vectors.dense(1.0, 3.0), 0.0),
+ (2, 2.0, 5.0, Vectors.dense(2.0, 5.0), 2.0))
).toDF("id", "v1", "v2", "features", "label")
// TODO(ekl) make schema comparisons ignore metadata, to avoid .toString
assert(result.schema.toString == resultSchema.toString)
assert(resultSchema == expected.schema)
- assert(result.collect().toSeq == expected.collect().toSeq)
+ assert(result.collect() === expected.collect())
}
test("features column already exists") {
val formula = new RFormula().setFormula("y ~ x").setFeaturesCol("x")
val original = sqlContext.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "y")
intercept[IllegalArgumentException] {
- formula.transformSchema(original.schema)
+ formula.fit(original)
}
intercept[IllegalArgumentException] {
- formula.transform(original)
+ formula.fit(original)
}
}
test("label column already exists") {
val formula = new RFormula().setFormula("y ~ x").setLabelCol("y")
val original = sqlContext.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "y")
- val resultSchema = formula.transformSchema(original.schema)
+ val model = formula.fit(original)
+ val resultSchema = model.transformSchema(original.schema)
assert(resultSchema.length == 3)
- assert(resultSchema.toString == formula.transform(original).schema.toString)
+ assert(resultSchema.toString == model.transform(original).schema.toString)
}
test("label column already exists but is not double type") {
val formula = new RFormula().setFormula("y ~ x").setLabelCol("y")
val original = sqlContext.createDataFrame(Seq((0, 1), (2, 2))).toDF("x", "y")
+ val model = formula.fit(original)
intercept[IllegalArgumentException] {
- formula.transformSchema(original.schema)
+ model.transformSchema(original.schema)
}
intercept[IllegalArgumentException] {
- formula.transform(original)
+ model.transform(original)
}
}
test("allow missing label column for test datasets") {
val formula = new RFormula().setFormula("y ~ x").setLabelCol("label")
val original = sqlContext.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "_not_y")
- val resultSchema = formula.transformSchema(original.schema)
+ val model = formula.fit(original)
+ val resultSchema = model.transformSchema(original.schema)
assert(resultSchema.length == 3)
assert(!resultSchema.exists(_.name == "label"))
- assert(resultSchema.toString == formula.transform(original).schema.toString)
+ assert(resultSchema.toString == model.transform(original).schema.toString)
}
-// TODO(ekl) enable after we implement string label support
-// test("transform string label") {
-// val formula = new RFormula().setFormula("name ~ id")
-// val original = sqlContext.createDataFrame(
-// Seq((1, "foo"), (2, "bar"), (3, "bar"))).toDF("id", "name")
-// val result = formula.transform(original)
-// val resultSchema = formula.transformSchema(original.schema)
-// val expected = sqlContext.createDataFrame(
-// Seq(
-// (1, "foo", Vectors.dense(Array(1.0)), 1.0),
-// (2, "bar", Vectors.dense(Array(2.0)), 0.0),
-// (3, "bar", Vectors.dense(Array(3.0)), 0.0))
-// ).toDF("id", "name", "features", "label")
-// assert(result.schema.toString == resultSchema.toString)
-// assert(result.collect().toSeq == expected.collect().toSeq)
-// }
+ test("encodes string terms") {
+ val formula = new RFormula().setFormula("id ~ a + b")
+ val original = sqlContext.createDataFrame(
+ Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5))
+ ).toDF("id", "a", "b")
+ val model = formula.fit(original)
+ val result = model.transform(original)
+ val resultSchema = model.transformSchema(original.schema)
+ val expected = sqlContext.createDataFrame(
+ Seq(
+ (1, "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0),
+ (2, "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 2.0),
+ (3, "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 3.0),
+ (4, "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 4.0))
+ ).toDF("id", "a", "b", "features", "label")
+ assert(result.schema.toString == resultSchema.toString)
+ assert(result.collect() === expected.collect())
+ }
}