aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
author Eric Liang <ekl@databricks.com> 2015-07-15 20:33:06 -0700
committer Xiangrui Meng <meng@databricks.com> 2015-07-15 20:33:06 -0700
commit 6960a7938c61cc07f181ca85e0d8152ceeb453d9 (patch)
tree 9735bcac2f3c0c9645cf142f01213d557a6fe2b6 /mllib
parent b0645195d0da57065885e078e08bd6c42f4f19b0 (diff)
download spark-6960a7938c61cc07f181ca85e0d8152ceeb453d9.tar.gz
spark-6960a7938c61cc07f181ca85e0d8152ceeb453d9.tar.bz2
spark-6960a7938c61cc07f181ca85e0d8152ceeb453d9.zip
[SPARK-8774] [ML] Add R model formula with basic support as a transformer
This implements minimal R formula support as a feature transformer. Both numeric and string labels are supported, but features must be numeric for now. cc mengxr Author: Eric Liang <ekl@databricks.com> Closes #7381 from ericl/spark-8774-1 and squashes the following commits: d1959d2 [Eric Liang] clarify comment 2db68aa [Eric Liang] second round of comments dc3c943 [Eric Liang] address comments 5765ec6 [Eric Liang] fix style checks 1f361b0 [Eric Liang] doc fb0826b [Eric Liang] [SPARK-8774] Add R model formula with basic support as a transformer
Diffstat (limited to 'mllib')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala | 151
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala | 2
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala | 34
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala | 93
4 files changed, 279 insertions, 1 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
new file mode 100644
index 0000000000..d9a36bda38
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import scala.util.parsing.combinator.RegexParsers
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.ml.Transformer
+import org.apache.spark.ml.param.{Param, ParamMap}
+import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol}
+import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types._
+
+/**
+ * :: Experimental ::
+ * Implements the transforms required for fitting a dataset against an R model formula. Currently
+ * we support a limited subset of the R operators, including '~' and '+'. Also see the R formula
+ * docs here: http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html
+ */
+@Experimental
+class RFormula(override val uid: String)
+  extends Transformer with HasFeaturesCol with HasLabelCol {
+
+  def this() = this(Identifiable.randomUID("rFormula"))
+
+  /**
+   * R formula parameter. The formula is provided in string form.
+   * @group param
+   */
+  val formula: Param[String] = new Param(this, "formula", "R model formula")
+
+  /**
+   * Parsed form of the `formula` param, or None while the param is unset.
+   *
+   * This is derived from the param on each access rather than cached in a mutable field, so
+   * that an instance whose formula arrived by any route other than `setFormula` (for example
+   * through `copy`, which carries over params only) still sees its formula.
+   */
+  private def parsedFormula: Option[ParsedRFormula] = get(formula).map(RFormulaParser.parse)
+
+  /**
+   * Sets the formula to use for this transformer. Must be called before use.
+   * @group setParam
+   * @param value an R formula in string form (e.g. "y ~ x + z")
+   * @throws IllegalArgumentException if the formula cannot be parsed
+   */
+  def setFormula(value: String): this.type = {
+    // Parse eagerly so that a malformed formula is rejected here, not at transform time.
+    RFormulaParser.parse(value)
+    set(formula, value)
+  }
+
+  /** @group getParam */
+  def getFormula: String = $(formula)
+
+  /** @group setParam */
+  def setFeaturesCol(value: String): this.type = set(featuresCol, value)
+
+  /** @group setParam */
+  def setLabelCol(value: String): this.type = set(labelCol, value)
+
+  /**
+   * Computes the output schema: the input fields, the assembled features column, and (unless
+   * the input already contains it) a DoubleType label column.
+   */
+  override def transformSchema(schema: StructType): StructType = {
+    checkCanTransform(schema)
+    val withFeatures = transformFeatures.transformSchema(schema)
+    if (hasLabelCol(schema)) {
+      withFeatures
+    } else {
+      // `.get` is safe here: checkCanTransform has verified that the formula is set.
+      val nullable = schema(parsedFormula.get.label).dataType match {
+        case _: NumericType | BooleanType => false
+        case _ => true
+      }
+      StructType(withFeatures.fields :+ StructField($(labelCol), DoubleType, nullable))
+    }
+  }
+
+  /** Appends the features column and, if it is not already present, the label column. */
+  override def transform(dataset: DataFrame): DataFrame = {
+    checkCanTransform(dataset.schema)
+    transformLabel(transformFeatures.transform(dataset))
+  }
+
+  override def copy(extra: ParamMap): RFormula = defaultCopy(extra)
+
+  override def toString: String = s"RFormula(${get(formula)})"
+
+  /**
+   * Adds the label column by casting the formula's label term to double. The dataset is
+   * returned unchanged when it already has a column named after the label param.
+   */
+  private def transformLabel(dataset: DataFrame): DataFrame = {
+    if (hasLabelCol(dataset.schema)) {
+      dataset
+    } else {
+      // Callers run checkCanTransform first, so the formula is known to be set.
+      val labelName = parsedFormula.get.label
+      dataset.schema(labelName).dataType match {
+        case _: NumericType | BooleanType =>
+          dataset.withColumn($(labelCol), dataset(labelName).cast(DoubleType))
+        // TODO(ekl) add support for string-type labels
+        case other =>
+          throw new IllegalArgumentException("Unsupported type for label: " + other)
+      }
+    }
+  }
+
+  /** Builds the assembler that packs the formula's terms into the features column. */
+  private def transformFeatures: Transformer = {
+    // TODO(ekl) add support for non-numeric features and feature interactions
+    new VectorAssembler(uid)
+      .setInputCols(parsedFormula.get.terms.toArray)
+      .setOutputCol($(featuresCol))
+  }
+
+  /**
+   * Validates that this transformer can run against `schema`: the formula must be set, the
+   * features column must not exist yet, and an existing label column must be DoubleType.
+   */
+  private def checkCanTransform(schema: StructType): Unit = {
+    require(parsedFormula.isDefined, "Must call setFormula() first.")
+    val columnNames = schema.map(_.name)
+    require(!columnNames.contains($(featuresCol)), "Features column already exists.")
+    require(
+      !columnNames.contains($(labelCol)) || schema($(labelCol)).dataType == DoubleType,
+      "Label column already exists and is not of type DoubleType.")
+  }
+
+  /** Whether `schema` already has a column named after the label param. */
+  private def hasLabelCol(schema: StructType): Boolean = {
+    schema.map(_.name).contains($(labelCol))
+  }
+}
+
+/**
+ * Result of parsing an R model formula.
+ *
+ * @param label the term on the left-hand side of '~'
+ * @param terms the terms on the right-hand side of '~', in the order written
+ */
+private[ml] case class ParsedRFormula(
+    label: String,
+    terms: Seq[String])
+
+/**
+ * Limited implementation of R formula parsing. Currently supports: '~', '+'.
+ */
+private[ml] object RFormulaParser extends RegexParsers {
+  /** A single term: starts with a letter, or with '.' followed by a letter or '_'. */
+  def term: Parser[String] = "([a-zA-Z]|\\.[a-zA-Z_])[a-zA-Z0-9._]*".r
+
+  /** One or more terms separated by '+', returned in the order written. */
+  def expr: Parser[List[String]] = term ~ rep("+" ~> term) ^^ { case a ~ list => a :: list }
+
+  /** A complete formula: the label term, then '~', then the feature terms. */
+  def formula: Parser[ParsedRFormula] =
+    term ~ ("~" ~> expr) ^^ { case label ~ terms => ParsedRFormula(label, terms) }
+
+  /**
+   * Parses an R formula string.
+   *
+   * @param value the formula in string form (e.g. "y ~ x + z")
+   * @return the parsed label and terms
+   * @throws IllegalArgumentException if `value` as a whole is not a supported formula
+   */
+  def parse(value: String): ParsedRFormula = parseAll(formula, value) match {
+    case Success(result, _) => result
+    // Surface the parser's own diagnostic so the caller can see what was expected.
+    case failure: NoSuccess => throw new IllegalArgumentException(
+      s"Could not parse formula: $value (${failure.msg})")
+  }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
index 9f83c2ee16..086917fa68 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
@@ -116,7 +116,7 @@ class VectorAssembler(override val uid: String)
if (schema.fieldNames.contains(outputColName)) {
throw new IllegalArgumentException(s"Output column $outputColName already exists.")
}
- StructType(schema.fields :+ new StructField(outputColName, new VectorUDT, false))
+ StructType(schema.fields :+ new StructField(outputColName, new VectorUDT, true))
}
override def copy(extra: ParamMap): VectorAssembler = defaultCopy(extra)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala
new file mode 100644
index 0000000000..c8d065f37a
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.SparkFunSuite
+
+class RFormulaParserSuite extends SparkFunSuite {
+  /** Parses `formula` and asserts that it yields the given label and terms. */
+  private def checkParse(formula: String, label: String, terms: Seq[String]): Unit = {
+    val parsed = RFormulaParser.parse(formula)
+    assert(parsed.label == label)
+    assert(parsed.terms == terms)
+  }
+
+  test("parse simple formulas") {
+    checkParse("y ~ x", "y", Seq("x"))
+    checkParse("y ~ ._foo ", "y", Seq("._foo"))
+    checkParse("resp ~ A_VAR + B + c123", "resp", Seq("A_VAR", "B", "c123"))
+  }
+
+  test("reject malformed formulas") {
+    // Empty input, missing '~', missing either side, dangling '+', and a bad leading character.
+    Seq("", "y", "y ~", "~ x", "y ~ x +", "1y ~ x").foreach { malformed =>
+      intercept[IllegalArgumentException] {
+        RFormulaParser.parse(malformed)
+      }
+    }
+  }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
new file mode 100644
index 0000000000..fa8611b243
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+
+class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext {
+
+  /** Asserts that `formula` rejects `df` in both transformSchema and transform. */
+  private def assertRejected(formula: RFormula, df: org.apache.spark.sql.DataFrame): Unit = {
+    intercept[IllegalArgumentException] {
+      formula.transformSchema(df.schema)
+    }
+    intercept[IllegalArgumentException] {
+      formula.transform(df)
+    }
+  }
+
+  test("params") {
+    val untouched = new RFormula()
+    ParamsSuite.checkParams(untouched)
+  }
+
+  test("transform numeric data") {
+    val rFormula = new RFormula().setFormula("id ~ v1 + v2")
+    val input = sqlContext
+      .createDataFrame(Seq((0, 1.0, 3.0), (2, 2.0, 5.0)))
+      .toDF("id", "v1", "v2")
+    val transformed = rFormula.transform(input)
+    val predictedSchema = rFormula.transformSchema(input.schema)
+    val expectedRows = Seq(
+      (0, 1.0, 3.0, Vectors.dense(Array(1.0, 3.0)), 0.0),
+      (2, 2.0, 5.0, Vectors.dense(Array(2.0, 5.0)), 2.0))
+    val expected = sqlContext
+      .createDataFrame(expectedRows)
+      .toDF("id", "v1", "v2", "features", "label")
+    // TODO(ekl) make schema comparisons ignore metadata, to avoid .toString
+    assert(transformed.schema.toString == predictedSchema.toString)
+    assert(predictedSchema == expected.schema)
+    assert(transformed.collect().toSeq == expected.collect().toSeq)
+  }
+
+  test("features column already exists") {
+    // The requested features column "x" collides with an input column.
+    val rFormula = new RFormula().setFormula("y ~ x").setFeaturesCol("x")
+    val input = sqlContext.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "y")
+    assertRejected(rFormula, input)
+  }
+
+  test("label column already exists") {
+    // "y" is already a double column, so it is reused as the label and only features is added.
+    val rFormula = new RFormula().setFormula("y ~ x").setLabelCol("y")
+    val input = sqlContext.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "y")
+    val predictedSchema = rFormula.transformSchema(input.schema)
+    assert(predictedSchema.length == 3)
+    assert(predictedSchema.toString == rFormula.transform(input).schema.toString)
+  }
+
+  test("label column already exists but is not double type") {
+    // Here "y" is an integer column, which cannot serve as a pre-existing label column.
+    val rFormula = new RFormula().setFormula("y ~ x").setLabelCol("y")
+    val input = sqlContext.createDataFrame(Seq((0, 1), (2, 2))).toDF("x", "y")
+    assertRejected(rFormula, input)
+  }
+
+// TODO(ekl) enable after we implement string label support
+//  test("transform string label") {
+//    val formula = new RFormula().setFormula("name ~ id")
+//    val original = sqlContext.createDataFrame(
+//      Seq((1, "foo"), (2, "bar"), (3, "bar"))).toDF("id", "name")
+//    val result = formula.transform(original)
+//    val resultSchema = formula.transformSchema(original.schema)
+//    val expected = sqlContext.createDataFrame(
+//      Seq(
+//        (1, "foo", Vectors.dense(Array(1.0)), 1.0),
+//        (2, "bar", Vectors.dense(Array(2.0)), 0.0),
+//        (3, "bar", Vectors.dense(Array(3.0)), 0.0))
+//      ).toDF("id", "name", "features", "label")
+//    assert(result.schema.toString == resultSchema.toString)
+//    assert(result.collect().toSeq == expected.collect().toSeq)
+//  }
+}