aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorEric Liang <ekl@databricks.com>2015-07-28 14:16:57 -0700
committerXiangrui Meng <meng@databricks.com>2015-07-28 14:16:57 -0700
commit8d5bb5283c3cc9180ef34b05be4a715d83073b1e (patch)
treec91d0261b5212032a129bcad4d772f1b183a7ea8
parent6cdcc21fe654ac0a2d0d72783eb10005fc513af6 (diff)
downloadspark-8d5bb5283c3cc9180ef34b05be4a715d83073b1e.tar.gz
spark-8d5bb5283c3cc9180ef34b05be4a715d83073b1e.tar.bz2
spark-8d5bb5283c3cc9180ef34b05be4a715d83073b1e.zip
[SPARK-9391] [ML] Support minus, dot, and intercept operators in SparkR RFormula
Adds '.', '-', and intercept parsing to RFormula. Also splits RFormulaParser into a separate file. Umbrella design doc here: https://docs.google.com/document/d/10NZNSEurN2EdWM31uFYsgayIPfCFHiuIu3pCWrUmP_c/edit?usp=sharing mengxr Author: Eric Liang <ekl@databricks.com> Closes #7707 from ericl/string-features-2 and squashes the following commits: 8588625 [Eric Liang] exclude complex types for . 8106ffe [Eric Liang] comments a9350bb [Eric Liang] s/var/val 9c50d4d [Eric Liang] Merge branch 'string-features' into string-features-2 581afb2 [Eric Liang] Merge branch 'master' into string-features 08ae539 [Eric Liang] Merge branch 'string-features' into string-features-2 f99131a [Eric Liang] comments cecec43 [Eric Liang] Merge branch 'string-features' into string-features-2 0bf3c26 [Eric Liang] update docs 4592df2 [Eric Liang] intercept supports 7412a2e [Eric Liang] Fri Jul 24 14:56:51 PDT 2015 3cf848e [Eric Liang] fix the parser 0556c2b [Eric Liang] Merge branch 'string-features' into string-features-2 c302a2c [Eric Liang] fix tests 9d1ac82 [Eric Liang] Merge remote-tracking branch 'upstream/master' into string-features e713da3 [Eric Liang] comments cd231a9 [Eric Liang] Wed Jul 22 17:18:44 PDT 2015 4d79193 [Eric Liang] revert to seq + distinct 169a085 [Eric Liang] tweak functional test a230a47 [Eric Liang] Merge branch 'master' into string-features 72bd6f3 [Eric Liang] fix merge d841cec [Eric Liang] Merge branch 'master' into string-features 5b2c4a2 [Eric Liang] Mon Jul 20 18:45:33 PDT 2015 b01c7c5 [Eric Liang] add test 8a637db [Eric Liang] encoder wip a1d03f4 [Eric Liang] refactor into estimator
-rw-r--r--R/pkg/R/mllib.R2
-rw-r--r--R/pkg/inst/tests/test_mllib.R8
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala52
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala129
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala10
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala55
6 files changed, 215 insertions, 41 deletions
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
index 258e354081..6a8bacaa55 100644
--- a/R/pkg/R/mllib.R
+++ b/R/pkg/R/mllib.R
@@ -27,7 +27,7 @@ setClass("PipelineModel", representation(model = "jobj"))
#' Fits a generalized linear model, similarly to R's glm(). Also see the glmnet package.
#'
#' @param formula A symbolic description of the model to be fitted. Currently only a few formula
-#' operators are supported, including '~' and '+'.
+#' operators are supported, including '~', '+', '-', and '.'.
#' @param data DataFrame for training
#' @param family Error distribution. "gaussian" -> linear regression, "binomial" -> logistic reg.
#' @param lambda Regularization parameter
diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/test_mllib.R
index 29152a1168..3bef693247 100644
--- a/R/pkg/inst/tests/test_mllib.R
+++ b/R/pkg/inst/tests/test_mllib.R
@@ -40,3 +40,11 @@ test_that("predictions match with native glm", {
rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris)
expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)
})
+
+test_that("dot minus and intercept vs native glm", {
+ training <- createDataFrame(sqlContext, iris)
+ model <- glm(Sepal_Width ~ . - Species + 0, data = training)
+ vals <- collect(select(predict(model, training), "prediction"))
+ rVals <- predict(glm(Sepal.Width ~ . - Species + 0, data = iris), iris)
+ expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)
+})
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
index 0a95b1ee8d..0b428d278d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -78,13 +78,20 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R
/** @group getParam */
def getFormula: String = $(formula)
+ /** Whether the formula specifies fitting an intercept. */
+  private[ml] def hasIntercept: Boolean = {
+    // Read the parsed formula once; it is only defined after setFormula() has been called.
+    val maybeFormula = parsedFormula
+    require(maybeFormula.isDefined, "Must call setFormula() first.")
+    // The option is known to be non-empty here, so this simply forwards to the parsed formula.
+    maybeFormula.exists(_.hasIntercept)
+  }
+
override def fit(dataset: DataFrame): RFormulaModel = {
require(parsedFormula.isDefined, "Must call setFormula() first.")
+ val resolvedFormula = parsedFormula.get.resolve(dataset.schema)
// StringType terms and terms representing interactions need to be encoded before assembly.
// TODO(ekl) add support for feature interactions
- var encoderStages = ArrayBuffer[PipelineStage]()
- var tempColumns = ArrayBuffer[String]()
- val encodedTerms = parsedFormula.get.terms.map { term =>
+ val encoderStages = ArrayBuffer[PipelineStage]()
+ val tempColumns = ArrayBuffer[String]()
+ val encodedTerms = resolvedFormula.terms.map { term =>
dataset.schema(term) match {
case column if column.dataType == StringType =>
val indexCol = term + "_idx_" + uid
@@ -103,7 +110,7 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R
.setOutputCol($(featuresCol))
encoderStages += new ColumnPruner(tempColumns.toSet)
val pipelineModel = new Pipeline(uid).setStages(encoderStages.toArray).fit(dataset)
- copyValues(new RFormulaModel(uid, parsedFormula.get, pipelineModel).setParent(this))
+ copyValues(new RFormulaModel(uid, resolvedFormula, pipelineModel).setParent(this))
}
// optimistic schema; does not contain any ML attributes
@@ -124,13 +131,13 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R
/**
* :: Experimental ::
* A fitted RFormula. Fitting is required to determine the factor levels of formula terms.
- * @param parsedFormula a pre-parsed R formula.
+ * @param resolvedFormula the fitted R formula.
* @param pipelineModel the fitted feature model, including factor to index mappings.
*/
@Experimental
class RFormulaModel private[feature](
override val uid: String,
- parsedFormula: ParsedRFormula,
+ resolvedFormula: ResolvedRFormula,
pipelineModel: PipelineModel)
extends Model[RFormulaModel] with RFormulaBase {
@@ -144,8 +151,8 @@ class RFormulaModel private[feature](
val withFeatures = pipelineModel.transformSchema(schema)
if (hasLabelCol(schema)) {
withFeatures
- } else if (schema.exists(_.name == parsedFormula.label)) {
- val nullable = schema(parsedFormula.label).dataType match {
+ } else if (schema.exists(_.name == resolvedFormula.label)) {
+ val nullable = schema(resolvedFormula.label).dataType match {
case _: NumericType | BooleanType => false
case _ => true
}
@@ -158,12 +165,12 @@ class RFormulaModel private[feature](
}
override def copy(extra: ParamMap): RFormulaModel = copyValues(
- new RFormulaModel(uid, parsedFormula, pipelineModel))
+ new RFormulaModel(uid, resolvedFormula, pipelineModel))
- override def toString: String = s"RFormulaModel(${parsedFormula})"
+ override def toString: String = s"RFormulaModel(${resolvedFormula})"
private def transformLabel(dataset: DataFrame): DataFrame = {
- val labelName = parsedFormula.label
+ val labelName = resolvedFormula.label
if (hasLabelCol(dataset.schema)) {
dataset
} else if (dataset.schema.exists(_.name == labelName)) {
@@ -207,26 +214,3 @@ private class ColumnPruner(columnsToPrune: Set[String]) extends Transformer {
override def copy(extra: ParamMap): ColumnPruner = defaultCopy(extra)
}
-
-/**
- * Represents a parsed R formula.
- */
-private[ml] case class ParsedRFormula(label: String, terms: Seq[String])
-
-/**
- * Limited implementation of R formula parsing. Currently supports: '~', '+'.
- */
-private[ml] object RFormulaParser extends RegexParsers {
- def term: Parser[String] = "([a-zA-Z]|\\.[a-zA-Z_])[a-zA-Z0-9._]*".r
-
- def expr: Parser[List[String]] = term ~ rep("+" ~> term) ^^ { case a ~ list => a :: list }
-
- def formula: Parser[ParsedRFormula] =
- (term ~ "~" ~ expr) ^^ { case r ~ "~" ~ t => ParsedRFormula(r, t.distinct) }
-
- def parse(value: String): ParsedRFormula = parseAll(formula, value) match {
- case Success(result, _) => result
- case failure: NoSuccess => throw new IllegalArgumentException(
- "Could not parse formula: " + value)
- }
-}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala
new file mode 100644
index 0000000000..1ca3b92a7d
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala
@@ -0,0 +1,129 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import scala.util.parsing.combinator.RegexParsers
+
+import org.apache.spark.mllib.linalg.VectorUDT
+import org.apache.spark.sql.types._
+
+/**
+ * Represents a parsed R formula.
+ */
+private[ml] case class ParsedRFormula(label: ColumnRef, terms: Seq[Term]) {
+  /**
+   * Resolves formula terms into column names. A schema is necessary for inferring the meaning
+   * of the special '.' term. Terms are applied from left to right, so a deletion only removes
+   * columns added before it and a later addition can re-introduce them. Duplicate terms are
+   * removed during resolution.
+   *
+   * @param schema schema of the input dataset; only consulted when the formula contains '.'
+   * @return the label column name together with the ordered, de-duplicated feature columns
+   */
+  def resolve(schema: StructType): ResolvedRFormula = {
+    // Columns that '.' stands for: every simple-typed column except the label. Kept lazy so
+    // the schema is not touched unless the formula actually contains a '.'.
+    lazy val dotTerms = simpleTypes(schema).filter(_ != label.value)
+    val includedTerms = terms.foldLeft(Seq.empty[String]) {
+      case (acc, Dot) =>
+        acc ++ dotTerms
+      case (acc, ColumnRef(value)) =>
+        acc :+ value
+      case (acc, Deletion(ColumnRef(value))) =>
+        acc.filter(_ != value)
+      case (acc, Deletion(Dot)) =>
+        // e.g. "- .", which removes all first-order terms. The columns named by '.' must be
+        // dropped here (not retained), hence filterNot.
+        acc.filterNot(name => dotTerms.contains(name))
+      case (acc, Deletion(_: Deletion)) =>
+        assert(false, "Deletion terms cannot be nested")
+        acc
+      case (acc, Deletion(_: Intercept)) =>
+        // Intercept toggles do not affect the set of feature columns; see hasIntercept.
+        acc
+      case (acc, _: Intercept) =>
+        acc
+    }
+    ResolvedRFormula(label.value, includedTerms.distinct)
+  }
+
+  /**
+   * Whether this formula specifies fitting with an intercept term. The intercept is enabled
+   * by default; the last intercept toggle in the formula wins, and deleting a toggle inverts
+   * it (e.g. "- 1" disables the intercept, "- 0" enables it).
+   */
+  def hasIntercept: Boolean = terms.foldLeft(true) {
+    case (_, Intercept(enabled)) => enabled
+    case (_, Deletion(Intercept(enabled))) => !enabled
+    case (current, _) => current
+  }
+
+  // the dot operator excludes complex column types
+  private def simpleTypes(schema: StructType): Seq[String] = {
+    schema.fields.filter(_.dataType match {
+      case _: NumericType | StringType | BooleanType | _: VectorUDT => true
+      case _ => false
+    }).map(_.name)
+  }
+}
+
+/**
+ * Represents a fully evaluated and simplified R formula.
+ *
+ * @param label name of the label column (the left-hand side of '~')
+ * @param terms names of the feature columns selected by the right-hand side of the formula
+ */
+private[ml] case class ResolvedRFormula(label: String, terms: Seq[String])
+
+/**
+ * R formula terms. See the R formula docs here for more information:
+ * http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html
+ *
+ * The hierarchy is sealed, so all term kinds are declared in this file.
+ */
+private[ml] sealed trait Term
+
+/** R formula reference to all available columns, e.g. "." in a formula. */
+private[ml] case object Dot extends Term
+
+/** R formula reference to a single column by name, e.g. "+ Species" in a formula. */
+private[ml] case class ColumnRef(value: String) extends Term
+
+/** R formula intercept toggle, e.g. "+ 0" in a formula; `enabled` is true for "1". */
+private[ml] case class Intercept(enabled: Boolean) extends Term
+
+/** R formula deletion of a term, e.g. "- Species" in a formula. Deletions are not nested. */
+private[ml] case class Deletion(term: Term) extends Term
+
+/**
+ * Limited implementation of R formula parsing. Currently supports: '~', '+', '-', '.', and the
+ * intercept toggles '0' and '1'.
+ */
+private[ml] object RFormulaParser extends RegexParsers {
+  /** Parses an intercept toggle: "1" enables the intercept, "0" disables it. */
+  def intercept: Parser[Intercept] =
+    "([01])".r ^^ { case a => Intercept(a == "1") }
+
+  /** Parses a column name, which may start with a letter or with '.' followed by a letter/'_'. */
+  def columnRef: Parser[ColumnRef] =
+    "([a-zA-Z]|\\.[a-zA-Z_])[a-zA-Z0-9._]*".r ^^ { case a => ColumnRef(a) }
+
+  /** Parses a single term. A bare "." is tried last so that names like ".foo" stay columns. */
+  def term: Parser[Term] = intercept | columnRef | "\\.".r ^^ { case _ => Dot }
+
+  /**
+   * Parses the right-hand side of a formula: a term followed by any number of "+ term" or
+   * "- term" pairs. Terms introduced by "-" are wrapped in [[Deletion]]; order is preserved.
+   */
+  def terms: Parser[List[Term]] = (term ~ rep("+" ~ term | "-" ~ term)) ^^ {
+    case first ~ rest =>
+      // Single linear pass; the operator is always "+" or "-" because of the grammar above.
+      val following: List[Term] = rest.map {
+        case "-" ~ removed => Deletion(removed)
+        case _ ~ added => added
+      }
+      first :: following
+  }
+
+  /** Parses a whole formula of the form "label ~ terms". */
+  def formula: Parser[ParsedRFormula] =
+    (columnRef ~ "~" ~ terms) ^^ { case r ~ "~" ~ t => ParsedRFormula(r, t) }
+
+  /**
+   * Parses the given formula string.
+   *
+   * @param value the R formula, e.g. "y ~ . - x + 0"
+   * @return the parsed, still unresolved, formula
+   * @throws IllegalArgumentException if the whole string is not a valid formula; the message
+   *                                  includes the parser's own diagnostic
+   */
+  def parse(value: String): ParsedRFormula = parseAll(formula, value) match {
+    case Success(result, _) => result
+    case failure: NoSuccess => throw new IllegalArgumentException(
+      s"Could not parse formula: $value (${failure.msg})")
+  }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
index 1ee080641e..9f70592cca 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
@@ -32,8 +32,14 @@ private[r] object SparkRWrappers {
alpha: Double): PipelineModel = {
val formula = new RFormula().setFormula(value)
val estimator = family match {
- case "gaussian" => new LinearRegression().setRegParam(lambda).setElasticNetParam(alpha)
- case "binomial" => new LogisticRegression().setRegParam(lambda).setElasticNetParam(alpha)
+ case "gaussian" => new LinearRegression()
+ .setRegParam(lambda)
+ .setElasticNetParam(alpha)
+ .setFitIntercept(formula.hasIntercept)
+ case "binomial" => new LogisticRegression()
+ .setRegParam(lambda)
+ .setElasticNetParam(alpha)
+ .setFitIntercept(formula.hasIntercept)
}
val pipeline = new Pipeline().setStages(Array(formula, estimator))
pipeline.fit(df)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala
index c4b45aee06..436e66bab0 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala
@@ -18,12 +18,17 @@
package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.types._
class RFormulaParserSuite extends SparkFunSuite {
- private def checkParse(formula: String, label: String, terms: Seq[String]) {
- val parsed = RFormulaParser.parse(formula)
- assert(parsed.label == label)
- assert(parsed.terms == terms)
+ private def checkParse(
+ formula: String,
+ label: String,
+ terms: Seq[String],
+ schema: StructType = null) {
+ val resolved = RFormulaParser.parse(formula).resolve(schema)
+ assert(resolved.label == label)
+ assert(resolved.terms == terms)
}
test("parse simple formulas") {
@@ -32,4 +37,46 @@ class RFormulaParserSuite extends SparkFunSuite {
checkParse("y ~ ._foo ", "y", Seq("._foo"))
checkParse("resp ~ A_VAR + B + c123", "resp", Seq("A_VAR", "B", "c123"))
}
+
+ test("parse dot") {
+ val schema = (new StructType)
+ .add("a", "int", true)
+ .add("b", "long", false)
+ .add("c", "string", true)
+ checkParse("a ~ .", "a", Seq("b", "c"), schema)
+ }
+
+ test("parse deletion") {
+ val schema = (new StructType)
+ .add("a", "int", true)
+ .add("b", "long", false)
+ .add("c", "string", true)
+ checkParse("a ~ c - b", "a", Seq("c"), schema)
+ }
+
+ test("parse additions and deletions in order") {
+ val schema = (new StructType)
+ .add("a", "int", true)
+ .add("b", "long", false)
+ .add("c", "string", true)
+ checkParse("a ~ . - b + . - c", "a", Seq("b"), schema)
+ }
+
+ test("dot ignores complex column types") {
+ val schema = (new StructType)
+ .add("a", "int", true)
+ .add("b", "tinyint", false)
+ .add("c", "map<string, string>", true)
+ checkParse("a ~ .", "a", Seq("b"), schema)
+ }
+
+ test("parse intercept") {
+ assert(RFormulaParser.parse("a ~ b").hasIntercept)
+ assert(RFormulaParser.parse("a ~ b + 1").hasIntercept)
+ assert(RFormulaParser.parse("a ~ b - 0").hasIntercept)
+ assert(RFormulaParser.parse("a ~ b - 1 + 1").hasIntercept)
+ assert(!RFormulaParser.parse("a ~ b + 0").hasIntercept)
+ assert(!RFormulaParser.parse("a ~ b - 1").hasIntercept)
+ assert(!RFormulaParser.parse("a ~ b + 1 - 1").hasIntercept)
+ }
}