diff options
author | Eric Liang <ekl@databricks.com> | 2015-07-15 20:33:06 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-07-15 20:33:06 -0700 |
commit | 6960a7938c61cc07f181ca85e0d8152ceeb453d9 (patch) | |
tree | 9735bcac2f3c0c9645cf142f01213d557a6fe2b6 /mllib | |
parent | b0645195d0da57065885e078e08bd6c42f4f19b0 (diff) | |
download | spark-6960a7938c61cc07f181ca85e0d8152ceeb453d9.tar.gz spark-6960a7938c61cc07f181ca85e0d8152ceeb453d9.tar.bz2 spark-6960a7938c61cc07f181ca85e0d8152ceeb453d9.zip |
[SPARK-8774] [ML] Add R model formula with basic support as a transformer
This implements minimal R formula support as a feature transformer. Numeric and boolean labels are supported (they are cast to double); string labels are not yet supported in this patch and are left as a TODO, and features must be numeric for now.
cc mengxr
Author: Eric Liang <ekl@databricks.com>
Closes #7381 from ericl/spark-8774-1 and squashes the following commits:
d1959d2 [Eric Liang] clarify comment
2db68aa [Eric Liang] second round of comments
dc3c943 [Eric Liang] address comments
5765ec6 [Eric Liang] fix style checks
1f361b0 [Eric Liang] doc
fb0826b [Eric Liang] [SPARK-8774] Add R model formula with basic support as a transformer
Diffstat (limited to 'mllib')
4 files changed, 279 insertions, 1 deletion
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala new file mode 100644 index 0000000000..d9a36bda38 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import scala.util.parsing.combinator.RegexParsers + +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.Transformer +import org.apache.spark.ml.param.{Param, ParamMap} +import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol} +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ + +/** + * :: Experimental :: + * Implements the transforms required for fitting a dataset against an R model formula. Currently + * we support a limited subset of the R operators, including '~' and '+'. 
Also see the R formula + * docs here: http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html + */ +@Experimental +class RFormula(override val uid: String) + extends Transformer with HasFeaturesCol with HasLabelCol { + + def this() = this(Identifiable.randomUID("rFormula")) + + /** + * R formula parameter. The formula is provided in string form. + * @group setParam + */ + val formula: Param[String] = new Param(this, "formula", "R model formula") + + private var parsedFormula: Option[ParsedRFormula] = None + + /** + * Sets the formula to use for this transformer. Must be called before use. + * @group setParam + * @param value an R formula in string form (e.g. "y ~ x + z") + */ + def setFormula(value: String): this.type = { + parsedFormula = Some(RFormulaParser.parse(value)) + set(formula, value) + this + } + + /** @group getParam */ + def getFormula: String = $(formula) + + /** @group getParam */ + def setFeaturesCol(value: String): this.type = set(featuresCol, value) + + /** @group getParam */ + def setLabelCol(value: String): this.type = set(labelCol, value) + + override def transformSchema(schema: StructType): StructType = { + checkCanTransform(schema) + val withFeatures = transformFeatures.transformSchema(schema) + if (hasLabelCol(schema)) { + withFeatures + } else { + val nullable = schema(parsedFormula.get.label).dataType match { + case _: NumericType | BooleanType => false + case _ => true + } + StructType(withFeatures.fields :+ StructField($(labelCol), DoubleType, nullable)) + } + } + + override def transform(dataset: DataFrame): DataFrame = { + checkCanTransform(dataset.schema) + transformLabel(transformFeatures.transform(dataset)) + } + + override def copy(extra: ParamMap): RFormula = defaultCopy(extra) + + override def toString: String = s"RFormula(${get(formula)})" + + private def transformLabel(dataset: DataFrame): DataFrame = { + if (hasLabelCol(dataset.schema)) { + dataset + } else { + val labelName = parsedFormula.get.label + 
dataset.schema(labelName).dataType match { + case _: NumericType | BooleanType => + dataset.withColumn($(labelCol), dataset(labelName).cast(DoubleType)) + // TODO(ekl) add support for string-type labels + case other => + throw new IllegalArgumentException("Unsupported type for label: " + other) + } + } + } + + private def transformFeatures: Transformer = { + // TODO(ekl) add support for non-numeric features and feature interactions + new VectorAssembler(uid) + .setInputCols(parsedFormula.get.terms.toArray) + .setOutputCol($(featuresCol)) + } + + private def checkCanTransform(schema: StructType) { + require(parsedFormula.isDefined, "Must call setFormula() first.") + val columnNames = schema.map(_.name) + require(!columnNames.contains($(featuresCol)), "Features column already exists.") + require( + !columnNames.contains($(labelCol)) || schema($(labelCol)).dataType == DoubleType, + "Label column already exists and is not of type DoubleType.") + } + + private def hasLabelCol(schema: StructType): Boolean = { + schema.map(_.name).contains($(labelCol)) + } +} + +/** + * Represents a parsed R formula. + */ +private[ml] case class ParsedRFormula(label: String, terms: Seq[String]) + +/** + * Limited implementation of R formula parsing. Currently supports: '~', '+'. 
+ */ +private[ml] object RFormulaParser extends RegexParsers { + def term: Parser[String] = "([a-zA-Z]|\\.[a-zA-Z_])[a-zA-Z0-9._]*".r + + def expr: Parser[List[String]] = term ~ rep("+" ~> term) ^^ { case a ~ list => a :: list } + + def formula: Parser[ParsedRFormula] = + (term ~ "~" ~ expr) ^^ { case r ~ "~" ~ t => ParsedRFormula(r, t) } + + def parse(value: String): ParsedRFormula = parseAll(formula, value) match { + case Success(result, _) => result + case failure: NoSuccess => throw new IllegalArgumentException( + "Could not parse formula: " + value) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index 9f83c2ee16..086917fa68 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -116,7 +116,7 @@ class VectorAssembler(override val uid: String) if (schema.fieldNames.contains(outputColName)) { throw new IllegalArgumentException(s"Output column $outputColName already exists.") } - StructType(schema.fields :+ new StructField(outputColName, new VectorUDT, false)) + StructType(schema.fields :+ new StructField(outputColName, new VectorUDT, true)) } override def copy(extra: ParamMap): VectorAssembler = defaultCopy(extra) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala new file mode 100644 index 0000000000..c8d065f37a --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.SparkFunSuite + +class RFormulaParserSuite extends SparkFunSuite { + private def checkParse(formula: String, label: String, terms: Seq[String]) { + val parsed = RFormulaParser.parse(formula) + assert(parsed.label == label) + assert(parsed.terms == terms) + } + + test("parse simple formulas") { + checkParse("y ~ x", "y", Seq("x")) + checkParse("y ~ ._foo ", "y", Seq("._foo")) + checkParse("resp ~ A_VAR + B + c123", "resp", Seq("A_VAR", "B", "c123")) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala new file mode 100644 index 0000000000..fa8611b243 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.util.MLlibTestSparkContext + +class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext { + test("params") { + ParamsSuite.checkParams(new RFormula()) + } + + test("transform numeric data") { + val formula = new RFormula().setFormula("id ~ v1 + v2") + val original = sqlContext.createDataFrame( + Seq((0, 1.0, 3.0), (2, 2.0, 5.0))).toDF("id", "v1", "v2") + val result = formula.transform(original) + val resultSchema = formula.transformSchema(original.schema) + val expected = sqlContext.createDataFrame( + Seq( + (0, 1.0, 3.0, Vectors.dense(Array(1.0, 3.0)), 0.0), + (2, 2.0, 5.0, Vectors.dense(Array(2.0, 5.0)), 2.0)) + ).toDF("id", "v1", "v2", "features", "label") + // TODO(ekl) make schema comparisons ignore metadata, to avoid .toString + assert(result.schema.toString == resultSchema.toString) + assert(resultSchema == expected.schema) + assert(result.collect().toSeq == expected.collect().toSeq) + } + + test("features column already exists") { + val formula = new RFormula().setFormula("y ~ x").setFeaturesCol("x") + val original = sqlContext.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "y") + intercept[IllegalArgumentException] { + formula.transformSchema(original.schema) + } + intercept[IllegalArgumentException] { + formula.transform(original) + } + } + + test("label column already exists") { + val formula = new RFormula().setFormula("y 
~ x").setLabelCol("y") + val original = sqlContext.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "y") + val resultSchema = formula.transformSchema(original.schema) + assert(resultSchema.length == 3) + assert(resultSchema.toString == formula.transform(original).schema.toString) + } + + test("label column already exists but is not double type") { + val formula = new RFormula().setFormula("y ~ x").setLabelCol("y") + val original = sqlContext.createDataFrame(Seq((0, 1), (2, 2))).toDF("x", "y") + intercept[IllegalArgumentException] { + formula.transformSchema(original.schema) + } + intercept[IllegalArgumentException] { + formula.transform(original) + } + } + +// TODO(ekl) enable after we implement string label support +// test("transform string label") { +// val formula = new RFormula().setFormula("name ~ id") +// val original = sqlContext.createDataFrame( +// Seq((1, "foo"), (2, "bar"), (3, "bar"))).toDF("id", "name") +// val result = formula.transform(original) +// val resultSchema = formula.transformSchema(original.schema) +// val expected = sqlContext.createDataFrame( +// Seq( +// (1, "foo", Vectors.dense(Array(1.0)), 1.0), +// (2, "bar", Vectors.dense(Array(2.0)), 0.0), +// (3, "bar", Vectors.dense(Array(3.0)), 0.0)) +// ).toDF("id", "name", "features", "label") +// assert(result.schema.toString == resultSchema.toString) +// assert(result.collect().toSeq == expected.collect().toSeq) +// } +} |