From 922338812c03eba43f2f1a6c414d1b6b049811cf Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Fri, 25 Sep 2015 00:43:22 -0700 Subject: [SPARK-9681] [ML] Support R feature interactions in RFormula This integrates the Interaction feature transformer with SparkR R formula support (i.e. support `:`). To generate reasonable ML attribute names for feature interactions, it was necessary to add the ability to read the original attribute names back from `StructField`, and also to specify custom group prefixes in `VectorAssembler`. This also has the side-benefit of cleaning up the double-underscores in the attributes generated for non-interaction terms. mengxr Author: Eric Liang Closes #8830 from ericl/interaction-2. --- R/pkg/R/mllib.R | 2 +- R/pkg/inst/tests/test_mllib.R | 10 +- .../org/apache/spark/ml/attribute/attributes.scala | 16 ++- .../org/apache/spark/ml/feature/Interaction.scala | 12 ++- .../org/apache/spark/ml/feature/RFormula.scala | 113 +++++++++++++++++---- .../apache/spark/ml/feature/RFormulaParser.scala | 97 ++++++++++++++---- .../apache/spark/ml/feature/StringIndexer.scala | 5 +- .../spark/ml/feature/RFormulaParserSuite.scala | 89 +++++++++++++++- .../apache/spark/ml/feature/RFormulaSuite.scala | 76 +++++++++++++- python/pyspark/ml/feature.py | 2 +- 10 files changed, 362 insertions(+), 60 deletions(-) diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index cea3d760d0..474ada5956 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -27,7 +27,7 @@ setClass("PipelineModel", representation(model = "jobj")) #' Fits a generalized linear model, similarly to R's glm(). Also see the glmnet package. #' #' @param formula A symbolic description of the model to be fitted. Currently only a few formula -#' operators are supported, including '~', '+', '-', and '.'. +#' operators are supported, including '~', '.', ':', '+', and '-'. #' @param data DataFrame for training #' @param family Error distribution. "gaussian" -> linear regression, "binomial" -> logistic reg. 
#' @param lambda Regularization parameter diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/test_mllib.R index f272de78ad..032f8ec68b 100644 --- a/R/pkg/inst/tests/test_mllib.R +++ b/R/pkg/inst/tests/test_mllib.R @@ -49,6 +49,14 @@ test_that("dot minus and intercept vs native glm", { expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) }) +test_that("feature interaction vs native glm", { + training <- createDataFrame(sqlContext, iris) + model <- glm(Sepal_Width ~ Species:Sepal_Length, data = training) + vals <- collect(select(predict(model, training), "prediction")) + rVals <- predict(glm(Sepal.Width ~ Species:Sepal.Length, data = iris), iris) + expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) +}) + test_that("summary coefficients match with native glm", { training <- createDataFrame(sqlContext, iris) stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training)) @@ -57,5 +65,5 @@ test_that("summary coefficients match with native glm", { expect_true(all(abs(rCoefs - coefs) < 1e-6)) expect_true(all( as.character(stats$features) == - c("(Intercept)", "Sepal_Length", "Species__versicolor", "Species__virginica"))) + c("(Intercept)", "Sepal_Length", "Species_versicolor", "Species_virginica"))) }) diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala index e479f16902..a7c10333c0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala @@ -124,18 +124,28 @@ private[attribute] trait AttributeFactory { private[attribute] def fromMetadata(metadata: Metadata): Attribute /** - * Creates an [[Attribute]] from a [[StructField]] instance. + * Creates an [[Attribute]] from a [[StructField]] instance, optionally preserving name. 
*/ - def fromStructField(field: StructField): Attribute = { + private[ml] def decodeStructField(field: StructField, preserveName: Boolean): Attribute = { require(field.dataType.isInstanceOf[NumericType]) val metadata = field.metadata val mlAttr = AttributeKeys.ML_ATTR if (metadata.contains(mlAttr)) { - fromMetadata(metadata.getMetadata(mlAttr)).withName(field.name) + val attr = fromMetadata(metadata.getMetadata(mlAttr)) + if (preserveName) { + attr + } else { + attr.withName(field.name) + } } else { UnresolvedAttribute } } + + /** + * Creates an [[Attribute]] from a [[StructField]] instance. + */ + def fromStructField(field: StructField): Attribute = decodeStructField(field, false) } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala index 9194763fb3..37f7862476 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala @@ -149,8 +149,14 @@ class Interaction(override val uid: String) extends Transformer features.reverse.foreach { f => val encodedAttrs = f.dataType match { case _: NumericType | BooleanType => - val attr = Attribute.fromStructField(f) - encodedFeatureAttrs(Seq(attr), None) + val attr = Attribute.decodeStructField(f, preserveName = true) + if (attr == UnresolvedAttribute) { + encodedFeatureAttrs(Seq(NumericAttribute.defaultAttr.withName(f.name)), None) + } else if (!attr.name.isDefined) { + encodedFeatureAttrs(Seq(attr.withName(f.name)), None) + } else { + encodedFeatureAttrs(Seq(attr), None) + } case _: VectorUDT => val group = AttributeGroup.fromStructField(f) encodedFeatureAttrs(group.attributes.get, Some(group.name)) @@ -221,7 +227,7 @@ class Interaction(override val uid: String) extends Transformer * count is equal to the number of categories. For numeric features the count * should be set to 1. 
*/ -private[ml] class FeatureEncoder(numFeatures: Array[Int]) { +private[ml] class FeatureEncoder(numFeatures: Array[Int]) extends Serializable { assert(numFeatures.forall(_ > 0), "Features counts must all be positive.") /** The size of the output vector. */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index dcd6fe3c40..f9b840097f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -21,6 +21,7 @@ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.attribute.AttributeGroup import org.apache.spark.ml.{Estimator, Model, Pipeline, PipelineModel, PipelineStage, Transformer} import org.apache.spark.ml.param.{Param, ParamMap} import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol} @@ -42,8 +43,8 @@ private[feature] trait RFormulaBase extends HasFeaturesCol with HasLabelCol { /** * :: Experimental :: * Implements the transforms required for fitting a dataset against an R model formula. Currently - * we support a limited subset of the R operators, including '.', '~', '+', and '-'. Also see the - * R formula docs here: http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html + * we support a limited subset of the R operators, including '~', '.', ':', '+', and '-'. 
Also see + * the R formula docs here: http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html */ @Experimental class RFormula(override val uid: String) extends Estimator[RFormulaModel] with RFormulaBase { @@ -82,36 +83,54 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R require(isDefined(formula), "Formula must be defined first.") val parsedFormula = RFormulaParser.parse($(formula)) val resolvedFormula = parsedFormula.resolve(dataset.schema) - // StringType terms and terms representing interactions need to be encoded before assembly. - // TODO(ekl) add support for feature interactions val encoderStages = ArrayBuffer[PipelineStage]() + + val prefixesToRewrite = mutable.Map[String, String]() val tempColumns = ArrayBuffer[String]() - val takenNames = mutable.Set(dataset.columns: _*) - val encodedTerms = resolvedFormula.terms.map { term => + def tmpColumn(category: String): String = { + val col = Identifiable.randomUID(category) + tempColumns += col + col + } + + // First we index each string column referenced by the input terms. + val indexed: Map[String, String] = resolvedFormula.terms.flatten.distinct.map { term => dataset.schema(term) match { case column if column.dataType == StringType => - val indexCol = term + "_idx_" + uid - val encodedCol = { - var tmp = term - while (takenNames.contains(tmp)) { - tmp += "_" - } - tmp - } - takenNames.add(indexCol) - takenNames.add(encodedCol) - encoderStages += new StringIndexer().setInputCol(term).setOutputCol(indexCol) - encoderStages += new OneHotEncoder().setInputCol(indexCol).setOutputCol(encodedCol) - tempColumns += indexCol - tempColumns += encodedCol - encodedCol + val indexCol = tmpColumn("stridx") + encoderStages += new StringIndexer() + .setInputCol(term) + .setOutputCol(indexCol) + (term, indexCol) case _ => - term + (term, term) } + }.toMap + + // Then we handle one-hot encoding and interactions between terms. 
+ val encodedTerms = resolvedFormula.terms.map { + case Seq(term) if dataset.schema(term).dataType == StringType => + val encodedCol = tmpColumn("onehot") + encoderStages += new OneHotEncoder() + .setInputCol(indexed(term)) + .setOutputCol(encodedCol) + prefixesToRewrite(encodedCol + "_") = term + "_" + encodedCol + case Seq(term) => + term + case terms => + val interactionCol = tmpColumn("interaction") + encoderStages += new Interaction() + .setInputCols(terms.map(indexed).toArray) + .setOutputCol(interactionCol) + prefixesToRewrite(interactionCol + "_") = "" + interactionCol } + encoderStages += new VectorAssembler(uid) .setInputCols(encodedTerms.toArray) .setOutputCol($(featuresCol)) + encoderStages += new VectorAttributeRewriter($(featuresCol), prefixesToRewrite.toMap) encoderStages += new ColumnPruner(tempColumns.toSet) val pipelineModel = new Pipeline(uid).setStages(encoderStages.toArray).fit(dataset) copyValues(new RFormulaModel(uid, resolvedFormula, pipelineModel).setParent(this)) @@ -218,3 +237,53 @@ private class ColumnPruner(columnsToPrune: Set[String]) extends Transformer { override def copy(extra: ParamMap): ColumnPruner = defaultCopy(extra) } + +/** + * Utility transformer that rewrites Vector attribute names via prefix replacement. For example, + * it can rewrite attribute names starting with 'foo_' to start with 'bar_' instead. + * + * @param vectorCol name of the vector column to rewrite. + * @param prefixesToRewrite the map of string prefixes to their replacement values. Each attribute + * name defined in vectorCol will be checked against the keys of this + * map. When a key prefixes a name, the matching prefix will be replaced + * by the value in the map. 
+ */ +private class VectorAttributeRewriter( + vectorCol: String, + prefixesToRewrite: Map[String, String]) + extends Transformer { + + override val uid = Identifiable.randomUID("vectorAttrRewriter") + + override def transform(dataset: DataFrame): DataFrame = { + val metadata = { + val group = AttributeGroup.fromStructField(dataset.schema(vectorCol)) + val attrs = group.attributes.get.map { attr => + if (attr.name.isDefined) { + val name = attr.name.get + val replacement = prefixesToRewrite.filter { case (k, _) => name.startsWith(k) } + if (replacement.nonEmpty) { + val (k, v) = replacement.headOption.get + attr.withName(v + name.stripPrefix(k)) + } else { + attr + } + } else { + attr + } + } + new AttributeGroup(vectorCol, attrs).toMetadata() + } + val otherCols = dataset.columns.filter(_ != vectorCol).map(dataset.col) + val rewrittenCol = dataset.col(vectorCol).as(vectorCol, metadata) + dataset.select((otherCols :+ rewrittenCol): _*) + } + + override def transformSchema(schema: StructType): StructType = { + StructType( + schema.fields.filter(_.name != vectorCol) ++ + schema.fields.filter(_.name == vectorCol)) + } + + override def copy(extra: ParamMap): VectorAttributeRewriter = defaultCopy(extra) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala index 1ca3b92a7d..4079b387e1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala @@ -17,6 +17,7 @@ package org.apache.spark.ml.feature +import scala.collection.mutable import scala.util.parsing.combinator.RegexParsers import org.apache.spark.mllib.linalg.VectorUDT @@ -31,27 +32,35 @@ private[ml] case class ParsedRFormula(label: ColumnRef, terms: Seq[Term]) { * of the special '.' term. Duplicate terms will be removed during resolution. 
*/ def resolve(schema: StructType): ResolvedRFormula = { - var includedTerms = Seq[String]() + val dotTerms = expandDot(schema) + var includedTerms = Seq[Seq[String]]() terms.foreach { + case col: ColumnRef => + includedTerms :+= Seq(col.value) + case ColumnInteraction(cols) => + includedTerms ++= expandInteraction(schema, cols) case Dot => - includedTerms ++= simpleTypes(schema).filter(_ != label.value) - case ColumnRef(value) => - includedTerms :+= value + includedTerms ++= dotTerms.map(Seq(_)) case Deletion(term: Term) => term match { - case ColumnRef(value) => - includedTerms = includedTerms.filter(_ != value) + case inner: ColumnRef => + includedTerms = includedTerms.filter(_ != Seq(inner.value)) + case ColumnInteraction(cols) => + val fromInteraction = expandInteraction(schema, cols).map(_.toSet) + includedTerms = includedTerms.filter(t => !fromInteraction.contains(t.toSet)) case Dot => // e.g. "- .", which removes all first-order terms - val fromSchema = simpleTypes(schema) - includedTerms = includedTerms.filter(fromSchema.contains(_)) + includedTerms = includedTerms.filter { + case Seq(t) => !dotTerms.contains(t) + case _ => true + } case _: Deletion => - assert(false, "Deletion terms cannot be nested") + throw new RuntimeException("Deletion terms cannot be nested") case _: Intercept => } case _: Intercept => } - ResolvedRFormula(label.value, includedTerms.distinct) + ResolvedRFormula(label.value, includedTerms.distinct, hasIntercept) } /** Whether this formula specifies fitting with an intercept term. 
*/ @@ -67,19 +76,54 @@ private[ml] case class ParsedRFormula(label: ColumnRef, terms: Seq[Term]) { intercept } + // expands the Dot operators in interaction terms + private def expandInteraction( + schema: StructType, terms: Seq[InteractableTerm]): Seq[Seq[String]] = { + if (terms.isEmpty) { + return Seq(Nil) + } + + val rest = expandInteraction(schema, terms.tail) + val validInteractions = (terms.head match { + case Dot => + expandDot(schema).flatMap { t => + rest.map { r => + Seq(t) ++ r + } + } + case ColumnRef(value) => + rest.map(Seq(value) ++ _) + }).map(_.distinct) + + // Deduplicates feature interactions, for example, a:b is the same as b:a. + var seen = mutable.Set[Set[String]]() + validInteractions.flatMap { + case t if seen.contains(t.toSet) => + None + case t => + seen += t.toSet + Some(t) + }.sortBy(_.length) + } + // the dot operator excludes complex column types - private def simpleTypes(schema: StructType): Seq[String] = { + private def expandDot(schema: StructType): Seq[String] = { schema.fields.filter(_.dataType match { case _: NumericType | StringType | BooleanType | _: VectorUDT => true case _ => false - }).map(_.name) + }).map(_.name).filter(_ != label.value) } } /** * Represents a fully evaluated and simplified R formula. + * @param label the column name of the R formula label (response variable). + * @param terms the simplified terms of the R formula. Interaction terms are represented as Seqs + * of column names; non-interaction terms as length 1 Seqs. + * @param hasIntercept whether the formula specifies fitting with an intercept. */ -private[ml] case class ResolvedRFormula(label: String, terms: Seq[String]) +private[ml] case class ResolvedRFormula( + label: String, terms: Seq[Seq[String]], hasIntercept: Boolean) /** * R formula terms. 
See the R formula docs here for more information: @@ -87,11 +131,17 @@ private[ml] case class ResolvedRFormula(label: String, terms: Seq[String]) */ private[ml] sealed trait Term +/** A term that may be part of an interaction, e.g. 'x' in 'x:y' */ +private[ml] sealed trait InteractableTerm extends Term + /* R formula reference to all available columns, e.g. "." in a formula */ -private[ml] case object Dot extends Term +private[ml] case object Dot extends InteractableTerm /* R formula reference to a column, e.g. "+ Species" in a formula */ -private[ml] case class ColumnRef(value: String) extends Term +private[ml] case class ColumnRef(value: String) extends InteractableTerm + +/* R formula interaction of several columns, e.g. "Sepal_Length:Species" in a formula */ +private[ml] case class ColumnInteraction(terms: Seq[InteractableTerm]) extends Term /* R formula intercept toggle, e.g. "+ 0" in a formula */ private[ml] case class Intercept(enabled: Boolean) extends Term @@ -100,25 +150,30 @@ private[ml] case class Intercept(enabled: Boolean) extends Term private[ml] case class Deletion(term: Term) extends Term /** - * Limited implementation of R formula parsing. Currently supports: '~', '+', '-', '.'. + * Limited implementation of R formula parsing. Currently supports: '~', '+', '-', '.', ':'. 
*/ private[ml] object RFormulaParser extends RegexParsers { - def intercept: Parser[Intercept] = + private val intercept: Parser[Intercept] = "([01])".r ^^ { case a => Intercept(a == "1") } - def columnRef: Parser[ColumnRef] = + private val columnRef: Parser[ColumnRef] = "([a-zA-Z]|\\.[a-zA-Z_])[a-zA-Z0-9._]*".r ^^ { case a => ColumnRef(a) } - def term: Parser[Term] = intercept | columnRef | "\\.".r ^^ { case _ => Dot } + private val dot: Parser[InteractableTerm] = "\\.".r ^^ { case _ => Dot } + + private val interaction: Parser[List[InteractableTerm]] = rep1sep(columnRef | dot, ":") + + private val term: Parser[Term] = intercept | + interaction ^^ { case terms => ColumnInteraction(terms) } | dot | columnRef - def terms: Parser[List[Term]] = (term ~ rep("+" ~ term | "-" ~ term)) ^^ { + private val terms: Parser[List[Term]] = (term ~ rep("+" ~ term | "-" ~ term)) ^^ { case op ~ list => list.foldLeft(List(op)) { case (left, "+" ~ right) => left ++ Seq(right) case (left, "-" ~ right) => left ++ Seq(Deletion(right)) } } - def formula: Parser[ParsedRFormula] = + private val formula: Parser[ParsedRFormula] = (columnRef ~ "~" ~ terms) ^^ { case r ~ "~" ~ t => ParsedRFormula(r, t) } def parse(value: String): ParsedRFormula = parseAll(formula, value) match { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 2b1592930e..486274cd75 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -147,9 +147,8 @@ class StringIndexerModel ( } } - val outputColName = $(outputCol) val metadata = NominalAttribute.defaultAttr - .withName(outputColName).withValues(labels).toMetadata() + .withName($(inputCol)).withValues(labels).toMetadata() // If we are skipping invalid records, filter them out. 
val filteredDataset = (getHandleInvalid) match { case "skip" => { @@ -161,7 +160,7 @@ class StringIndexerModel ( case _ => dataset } filteredDataset.select(col("*"), - indexer(dataset($(inputCol)).cast(StringType)).as(outputColName, metadata)) + indexer(dataset($(inputCol)).cast(StringType)).as($(outputCol), metadata)) } override def transformSchema(schema: StructType): StructType = { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala index 436e66bab0..53798c659d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala @@ -25,16 +25,24 @@ class RFormulaParserSuite extends SparkFunSuite { formula: String, label: String, terms: Seq[String], - schema: StructType = null) { + schema: StructType = new StructType) { val resolved = RFormulaParser.parse(formula).resolve(schema) assert(resolved.label == label) - assert(resolved.terms == terms) + val simpleTerms = terms.map { t => + if (t.contains(":")) { + t.split(":").toSeq + } else { + Seq(t) + } + } + assert(resolved.terms == simpleTerms) } test("parse simple formulas") { checkParse("y ~ x", "y", Seq("x")) checkParse("y ~ x + x", "y", Seq("x")) - checkParse("y ~ ._foo ", "y", Seq("._foo")) + checkParse("y~x+z", "y", Seq("x", "z")) + checkParse("y ~ ._fo..o ", "y", Seq("._fo..o")) checkParse("resp ~ A_VAR + B + c123", "resp", Seq("A_VAR", "B", "c123")) } @@ -79,4 +87,79 @@ class RFormulaParserSuite extends SparkFunSuite { assert(!RFormulaParser.parse("a ~ b - 1").hasIntercept) assert(!RFormulaParser.parse("a ~ b + 1 - 1").hasIntercept) } + + test("parse interactions") { + checkParse("y ~ a:b", "y", Seq("a:b")) + checkParse("y ~ ._a:._x", "y", Seq("._a:._x")) + checkParse("y ~ foo:bar", "y", Seq("foo:bar")) + checkParse("y ~ a : b : c", "y", Seq("a:b:c")) + checkParse("y ~ q + a:b:c + b:c + c:d + z", 
"y", Seq("q", "a:b:c", "b:c", "c:d", "z")) + } + + test("parse basic interactions with dot") { + val schema = (new StructType) + .add("a", "int", true) + .add("b", "long", false) + .add("c", "string", true) + .add("d", "string", true) + checkParse("a ~ .:b", "a", Seq("b", "c:b", "d:b"), schema) + checkParse("a ~ b:.", "a", Seq("b", "b:c", "b:d"), schema) + checkParse("a ~ .:b:.:.:c:d:.", "a", Seq("b:c:d"), schema) + } + + // Test data generated in R with terms.formula(y ~ .:., data = iris) + test("parse all to all iris interactions") { + val schema = (new StructType) + .add("Sepal.Length", "double", true) + .add("Sepal.Width", "double", true) + .add("Petal.Length", "double", true) + .add("Petal.Width", "double", true) + .add("Species", "string", true) + checkParse( + "y ~ .:.", + "y", + Seq( + "Sepal.Length", + "Sepal.Width", + "Petal.Length", + "Petal.Width", + "Species", + "Sepal.Length:Sepal.Width", + "Sepal.Length:Petal.Length", + "Sepal.Length:Petal.Width", + "Sepal.Length:Species", + "Sepal.Width:Petal.Length", + "Sepal.Width:Petal.Width", + "Sepal.Width:Species", + "Petal.Length:Petal.Width", + "Petal.Length:Species", + "Petal.Width:Species"), + schema) + } + + // Test data generated in R with terms.formula(y ~ .:. - Species:., data = iris) + test("parse interaction negation with iris") { + val schema = (new StructType) + .add("Sepal.Length", "double", true) + .add("Sepal.Width", "double", true) + .add("Petal.Length", "double", true) + .add("Petal.Width", "double", true) + .add("Species", "string", true) + checkParse("y ~ .:. - .:.", "y", Nil, schema) + checkParse( + "y ~ .:. 
- Species:.", + "y", + Seq( + "Sepal.Length", + "Sepal.Width", + "Petal.Length", + "Petal.Width", + "Sepal.Length:Sepal.Width", + "Sepal.Length:Petal.Length", + "Sepal.Length:Petal.Width", + "Sepal.Width:Petal.Length", + "Sepal.Width:Petal.Width", + "Petal.Length:Petal.Width"), + schema) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala index 6aed3243af..b56013008b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -118,9 +118,81 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext { val expectedAttrs = new AttributeGroup( "features", Array( - new BinaryAttribute(Some("a__bar"), Some(1)), - new BinaryAttribute(Some("a__foo"), Some(2)), + new BinaryAttribute(Some("a_bar"), Some(1)), + new BinaryAttribute(Some("a_foo"), Some(2)), new NumericAttribute(Some("b"), Some(3)))) assert(attrs === expectedAttrs) } + + test("numeric interaction") { + val formula = new RFormula().setFormula("a ~ b:c:d") + val original = sqlContext.createDataFrame( + Seq((1, 2, 4, 2), (2, 3, 4, 1)) + ).toDF("a", "b", "c", "d") + val model = formula.fit(original) + val result = model.transform(original) + val expected = sqlContext.createDataFrame( + Seq( + (1, 2, 4, 2, Vectors.dense(16.0), 1.0), + (2, 3, 4, 1, Vectors.dense(12.0), 2.0)) + ).toDF("a", "b", "c", "d", "features", "label") + assert(result.collect() === expected.collect()) + val attrs = AttributeGroup.fromStructField(result.schema("features")) + val expectedAttrs = new AttributeGroup( + "features", + Array[Attribute](new NumericAttribute(Some("b:c:d"), Some(1)))) + assert(attrs === expectedAttrs) + } + + test("factor numeric interaction") { + val formula = new RFormula().setFormula("id ~ a:b") + val original = sqlContext.createDataFrame( + Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5), 
(4, "baz", 5), (4, "baz", 5)) + ).toDF("id", "a", "b") + val model = formula.fit(original) + val result = model.transform(original) + val expected = sqlContext.createDataFrame( + Seq( + (1, "foo", 4, Vectors.dense(0.0, 0.0, 4.0), 1.0), + (2, "bar", 4, Vectors.dense(0.0, 4.0, 0.0), 2.0), + (3, "bar", 5, Vectors.dense(0.0, 5.0, 0.0), 3.0), + (4, "baz", 5, Vectors.dense(5.0, 0.0, 0.0), 4.0), + (4, "baz", 5, Vectors.dense(5.0, 0.0, 0.0), 4.0), + (4, "baz", 5, Vectors.dense(5.0, 0.0, 0.0), 4.0)) + ).toDF("id", "a", "b", "features", "label") + assert(result.collect() === expected.collect()) + val attrs = AttributeGroup.fromStructField(result.schema("features")) + val expectedAttrs = new AttributeGroup( + "features", + Array[Attribute]( + new NumericAttribute(Some("a_baz:b"), Some(1)), + new NumericAttribute(Some("a_bar:b"), Some(2)), + new NumericAttribute(Some("a_foo:b"), Some(3)))) + assert(attrs === expectedAttrs) + } + + test("factor factor interaction") { + val formula = new RFormula().setFormula("id ~ a:b") + val original = sqlContext.createDataFrame( + Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zz")) + ).toDF("id", "a", "b") + val model = formula.fit(original) + val result = model.transform(original) + val expected = sqlContext.createDataFrame( + Seq( + (1, "foo", "zq", Vectors.dense(0.0, 0.0, 1.0, 0.0), 1.0), + (2, "bar", "zq", Vectors.dense(1.0, 0.0, 0.0, 0.0), 2.0), + (3, "bar", "zz", Vectors.dense(0.0, 1.0, 0.0, 0.0), 3.0)) + ).toDF("id", "a", "b", "features", "label") + assert(result.collect() === expected.collect()) + val attrs = AttributeGroup.fromStructField(result.schema("features")) + val expectedAttrs = new AttributeGroup( + "features", + Array[Attribute]( + new NumericAttribute(Some("a_bar:b_zq"), Some(1)), + new NumericAttribute(Some("a_bar:b_zz"), Some(2)), + new NumericAttribute(Some("a_foo:b_zq"), Some(3)), + new NumericAttribute(Some("a_foo:b_zz"), Some(4)))) + assert(attrs === expectedAttrs) + } } diff --git 
a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index f41d72f877..a4e60f916b 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -1850,7 +1850,7 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol): Implements the transforms required for fitting a dataset against an R model formula. Currently we support a limited subset of the R - operators, including '~', '+', '-', and '.'. Also see the R formula + operators, including '~', '.', ':', '+', and '-'. Also see the R formula docs: http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html -- cgit v1.2.3