aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author Eric Liang <ekl@databricks.com> 2015-09-25 00:43:22 -0700
committer Xiangrui Meng <meng@databricks.com> 2015-09-25 00:43:22 -0700
commit 922338812c03eba43f2f1a6c414d1b6b049811cf (patch)
tree 2df940a08de0645e2b88ba69d0c63931f9ec1f2f
parent 21fd12cb17b9e08a0cc49b4fda801af947a4183b (diff)
download spark-922338812c03eba43f2f1a6c414d1b6b049811cf.tar.gz
spark-922338812c03eba43f2f1a6c414d1b6b049811cf.tar.bz2
spark-922338812c03eba43f2f1a6c414d1b6b049811cf.zip
[SPARK-9681] [ML] Support R feature interactions in RFormula
This integrates the Interaction feature transformer with SparkR R formula support (i.e. support `:`). To generate reasonable ML attribute names for feature interactions, it was necessary to add the ability to read the original attribute names back from `StructField`, and also to specify custom group prefixes in `VectorAssembler`. This also has the side-benefit of cleaning up the double-underscores in the attributes generated for non-interaction terms. mengxr Author: Eric Liang <ekl@databricks.com> Closes #8830 from ericl/interaction-2.
-rw-r--r-- R/pkg/R/mllib.R 2
-rw-r--r-- R/pkg/inst/tests/test_mllib.R 10
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala 16
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala 12
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala 113
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala 97
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala 5
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala 89
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala 76
-rw-r--r-- python/pyspark/ml/feature.py 2
10 files changed, 362 insertions, 60 deletions
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
index cea3d760d0..474ada5956 100644
--- a/R/pkg/R/mllib.R
+++ b/R/pkg/R/mllib.R
@@ -27,7 +27,7 @@ setClass("PipelineModel", representation(model = "jobj"))
#' Fits a generalized linear model, similarly to R's glm(). Also see the glmnet package.
#'
#' @param formula A symbolic description of the model to be fitted. Currently only a few formula
-#' operators are supported, including '~', '+', '-', and '.'.
+#' operators are supported, including '~', '.', ':', '+', and '-'.
#' @param data DataFrame for training
#' @param family Error distribution. "gaussian" -> linear regression, "binomial" -> logistic reg.
#' @param lambda Regularization parameter
diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/test_mllib.R
index f272de78ad..032f8ec68b 100644
--- a/R/pkg/inst/tests/test_mllib.R
+++ b/R/pkg/inst/tests/test_mllib.R
@@ -49,6 +49,14 @@ test_that("dot minus and intercept vs native glm", {
expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)
})
+test_that("feature interaction vs native glm", {
+ training <- createDataFrame(sqlContext, iris)
+ model <- glm(Sepal_Width ~ Species:Sepal_Length, data = training)
+ vals <- collect(select(predict(model, training), "prediction"))
+ rVals <- predict(glm(Sepal.Width ~ Species:Sepal.Length, data = iris), iris)
+ expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)
+})
+
test_that("summary coefficients match with native glm", {
training <- createDataFrame(sqlContext, iris)
stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training))
@@ -57,5 +65,5 @@ test_that("summary coefficients match with native glm", {
expect_true(all(abs(rCoefs - coefs) < 1e-6))
expect_true(all(
as.character(stats$features) ==
- c("(Intercept)", "Sepal_Length", "Species__versicolor", "Species__virginica")))
+ c("(Intercept)", "Sepal_Length", "Species_versicolor", "Species_virginica")))
})
diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala
index e479f16902..a7c10333c0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala
@@ -124,18 +124,28 @@ private[attribute] trait AttributeFactory {
private[attribute] def fromMetadata(metadata: Metadata): Attribute
/**
- * Creates an [[Attribute]] from a [[StructField]] instance.
+ * Creates an [[Attribute]] from a [[StructField]] instance, optionally preserving name.
*/
- def fromStructField(field: StructField): Attribute = {
+ private[ml] def decodeStructField(field: StructField, preserveName: Boolean): Attribute = {
require(field.dataType.isInstanceOf[NumericType])
val metadata = field.metadata
val mlAttr = AttributeKeys.ML_ATTR
if (metadata.contains(mlAttr)) {
- fromMetadata(metadata.getMetadata(mlAttr)).withName(field.name)
+ val attr = fromMetadata(metadata.getMetadata(mlAttr))
+ if (preserveName) {
+ attr
+ } else {
+ attr.withName(field.name)
+ }
} else {
UnresolvedAttribute
}
}
+
+ /**
+ * Creates an [[Attribute]] from a [[StructField]] instance.
+ */
+ def fromStructField(field: StructField): Attribute = decodeStructField(field, false)
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala
index 9194763fb3..37f7862476 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala
@@ -149,8 +149,14 @@ class Interaction(override val uid: String) extends Transformer
features.reverse.foreach { f =>
val encodedAttrs = f.dataType match {
case _: NumericType | BooleanType =>
- val attr = Attribute.fromStructField(f)
- encodedFeatureAttrs(Seq(attr), None)
+ val attr = Attribute.decodeStructField(f, preserveName = true)
+ if (attr == UnresolvedAttribute) {
+ encodedFeatureAttrs(Seq(NumericAttribute.defaultAttr.withName(f.name)), None)
+ } else if (!attr.name.isDefined) {
+ encodedFeatureAttrs(Seq(attr.withName(f.name)), None)
+ } else {
+ encodedFeatureAttrs(Seq(attr), None)
+ }
case _: VectorUDT =>
val group = AttributeGroup.fromStructField(f)
encodedFeatureAttrs(group.attributes.get, Some(group.name))
@@ -221,7 +227,7 @@ class Interaction(override val uid: String) extends Transformer
* count is equal to the number of categories. For numeric features the count
* should be set to 1.
*/
-private[ml] class FeatureEncoder(numFeatures: Array[Int]) {
+private[ml] class FeatureEncoder(numFeatures: Array[Int]) extends Serializable {
assert(numFeatures.forall(_ > 0), "Features counts must all be positive.")
/** The size of the output vector. */
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
index dcd6fe3c40..f9b840097f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -21,6 +21,7 @@ import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.annotation.Experimental
+import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.{Estimator, Model, Pipeline, PipelineModel, PipelineStage, Transformer}
import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol}
@@ -42,8 +43,8 @@ private[feature] trait RFormulaBase extends HasFeaturesCol with HasLabelCol {
/**
* :: Experimental ::
* Implements the transforms required for fitting a dataset against an R model formula. Currently
- * we support a limited subset of the R operators, including '.', '~', '+', and '-'. Also see the
- * R formula docs here: http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html
+ * we support a limited subset of the R operators, including '~', '.', ':', '+', and '-'. Also see
+ * the R formula docs here: http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html
*/
@Experimental
class RFormula(override val uid: String) extends Estimator[RFormulaModel] with RFormulaBase {
@@ -82,36 +83,54 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R
require(isDefined(formula), "Formula must be defined first.")
val parsedFormula = RFormulaParser.parse($(formula))
val resolvedFormula = parsedFormula.resolve(dataset.schema)
- // StringType terms and terms representing interactions need to be encoded before assembly.
- // TODO(ekl) add support for feature interactions
val encoderStages = ArrayBuffer[PipelineStage]()
+
+ val prefixesToRewrite = mutable.Map[String, String]()
val tempColumns = ArrayBuffer[String]()
- val takenNames = mutable.Set(dataset.columns: _*)
- val encodedTerms = resolvedFormula.terms.map { term =>
+ def tmpColumn(category: String): String = {
+ val col = Identifiable.randomUID(category)
+ tempColumns += col
+ col
+ }
+
+ // First we index each string column referenced by the input terms.
+ val indexed: Map[String, String] = resolvedFormula.terms.flatten.distinct.map { term =>
dataset.schema(term) match {
case column if column.dataType == StringType =>
- val indexCol = term + "_idx_" + uid
- val encodedCol = {
- var tmp = term
- while (takenNames.contains(tmp)) {
- tmp += "_"
- }
- tmp
- }
- takenNames.add(indexCol)
- takenNames.add(encodedCol)
- encoderStages += new StringIndexer().setInputCol(term).setOutputCol(indexCol)
- encoderStages += new OneHotEncoder().setInputCol(indexCol).setOutputCol(encodedCol)
- tempColumns += indexCol
- tempColumns += encodedCol
- encodedCol
+ val indexCol = tmpColumn("stridx")
+ encoderStages += new StringIndexer()
+ .setInputCol(term)
+ .setOutputCol(indexCol)
+ (term, indexCol)
case _ =>
- term
+ (term, term)
}
+ }.toMap
+
+ // Then we handle one-hot encoding and interactions between terms.
+ val encodedTerms = resolvedFormula.terms.map {
+ case Seq(term) if dataset.schema(term).dataType == StringType =>
+ val encodedCol = tmpColumn("onehot")
+ encoderStages += new OneHotEncoder()
+ .setInputCol(indexed(term))
+ .setOutputCol(encodedCol)
+ prefixesToRewrite(encodedCol + "_") = term + "_"
+ encodedCol
+ case Seq(term) =>
+ term
+ case terms =>
+ val interactionCol = tmpColumn("interaction")
+ encoderStages += new Interaction()
+ .setInputCols(terms.map(indexed).toArray)
+ .setOutputCol(interactionCol)
+ prefixesToRewrite(interactionCol + "_") = ""
+ interactionCol
}
+
encoderStages += new VectorAssembler(uid)
.setInputCols(encodedTerms.toArray)
.setOutputCol($(featuresCol))
+ encoderStages += new VectorAttributeRewriter($(featuresCol), prefixesToRewrite.toMap)
encoderStages += new ColumnPruner(tempColumns.toSet)
val pipelineModel = new Pipeline(uid).setStages(encoderStages.toArray).fit(dataset)
copyValues(new RFormulaModel(uid, resolvedFormula, pipelineModel).setParent(this))
@@ -218,3 +237,53 @@ private class ColumnPruner(columnsToPrune: Set[String]) extends Transformer {
override def copy(extra: ParamMap): ColumnPruner = defaultCopy(extra)
}
+
+/**
+ * Utility transformer that rewrites Vector attribute names via prefix replacement. For example,
+ * it can rewrite attribute names starting with 'foo_' to start with 'bar_' instead.
+ *
+ * @param vectorCol name of the vector column to rewrite.
+ * @param prefixesToRewrite the map of string prefixes to their replacement values. Each attribute
+ * name defined in vectorCol will be checked against the keys of this
+ * map. When a key prefixes a name, the matching prefix will be replaced
+ * by the value in the map.
+ */
+private class VectorAttributeRewriter(
+ vectorCol: String,
+ prefixesToRewrite: Map[String, String])
+ extends Transformer {
+
+ override val uid = Identifiable.randomUID("vectorAttrRewriter")
+
+ override def transform(dataset: DataFrame): DataFrame = {
+ val metadata = {
+ val group = AttributeGroup.fromStructField(dataset.schema(vectorCol))
+ val attrs = group.attributes.get.map { attr =>
+ if (attr.name.isDefined) {
+ val name = attr.name.get
+ val replacement = prefixesToRewrite.filter { case (k, _) => name.startsWith(k) }
+ if (replacement.nonEmpty) {
+ val (k, v) = replacement.headOption.get
+ attr.withName(v + name.stripPrefix(k))
+ } else {
+ attr
+ }
+ } else {
+ attr
+ }
+ }
+ new AttributeGroup(vectorCol, attrs).toMetadata()
+ }
+ val otherCols = dataset.columns.filter(_ != vectorCol).map(dataset.col)
+ val rewrittenCol = dataset.col(vectorCol).as(vectorCol, metadata)
+ dataset.select((otherCols :+ rewrittenCol): _*)
+ }
+
+ override def transformSchema(schema: StructType): StructType = {
+ StructType(
+ schema.fields.filter(_.name != vectorCol) ++
+ schema.fields.filter(_.name == vectorCol))
+ }
+
+ override def copy(extra: ParamMap): VectorAttributeRewriter = defaultCopy(extra)
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala
index 1ca3b92a7d..4079b387e1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala
@@ -17,6 +17,7 @@
package org.apache.spark.ml.feature
+import scala.collection.mutable
import scala.util.parsing.combinator.RegexParsers
import org.apache.spark.mllib.linalg.VectorUDT
@@ -31,27 +32,35 @@ private[ml] case class ParsedRFormula(label: ColumnRef, terms: Seq[Term]) {
* of the special '.' term. Duplicate terms will be removed during resolution.
*/
def resolve(schema: StructType): ResolvedRFormula = {
- var includedTerms = Seq[String]()
+ val dotTerms = expandDot(schema)
+ var includedTerms = Seq[Seq[String]]()
terms.foreach {
+ case col: ColumnRef =>
+ includedTerms :+= Seq(col.value)
+ case ColumnInteraction(cols) =>
+ includedTerms ++= expandInteraction(schema, cols)
case Dot =>
- includedTerms ++= simpleTypes(schema).filter(_ != label.value)
- case ColumnRef(value) =>
- includedTerms :+= value
+ includedTerms ++= dotTerms.map(Seq(_))
case Deletion(term: Term) =>
term match {
- case ColumnRef(value) =>
- includedTerms = includedTerms.filter(_ != value)
+ case inner: ColumnRef =>
+ includedTerms = includedTerms.filter(_ != Seq(inner.value))
+ case ColumnInteraction(cols) =>
+ val fromInteraction = expandInteraction(schema, cols).map(_.toSet)
+ includedTerms = includedTerms.filter(t => !fromInteraction.contains(t.toSet))
case Dot =>
// e.g. "- .", which removes all first-order terms
- val fromSchema = simpleTypes(schema)
- includedTerms = includedTerms.filter(fromSchema.contains(_))
+ includedTerms = includedTerms.filter {
+ case Seq(t) => !dotTerms.contains(t)
+ case _ => true
+ }
case _: Deletion =>
- assert(false, "Deletion terms cannot be nested")
+ throw new RuntimeException("Deletion terms cannot be nested")
case _: Intercept =>
}
case _: Intercept =>
}
- ResolvedRFormula(label.value, includedTerms.distinct)
+ ResolvedRFormula(label.value, includedTerms.distinct, hasIntercept)
}
/** Whether this formula specifies fitting with an intercept term. */
@@ -67,19 +76,54 @@ private[ml] case class ParsedRFormula(label: ColumnRef, terms: Seq[Term]) {
intercept
}
+ // expands the Dot operators in interaction terms
+ private def expandInteraction(
+ schema: StructType, terms: Seq[InteractableTerm]): Seq[Seq[String]] = {
+ if (terms.isEmpty) {
+ return Seq(Nil)
+ }
+
+ val rest = expandInteraction(schema, terms.tail)
+ val validInteractions = (terms.head match {
+ case Dot =>
+ expandDot(schema).flatMap { t =>
+ rest.map { r =>
+ Seq(t) ++ r
+ }
+ }
+ case ColumnRef(value) =>
+ rest.map(Seq(value) ++ _)
+ }).map(_.distinct)
+
+ // Deduplicates feature interactions, for example, a:b is the same as b:a.
+ var seen = mutable.Set[Set[String]]()
+ validInteractions.flatMap {
+ case t if seen.contains(t.toSet) =>
+ None
+ case t =>
+ seen += t.toSet
+ Some(t)
+ }.sortBy(_.length)
+ }
+
// the dot operator excludes complex column types
- private def simpleTypes(schema: StructType): Seq[String] = {
+ private def expandDot(schema: StructType): Seq[String] = {
schema.fields.filter(_.dataType match {
case _: NumericType | StringType | BooleanType | _: VectorUDT => true
case _ => false
- }).map(_.name)
+ }).map(_.name).filter(_ != label.value)
}
}
/**
* Represents a fully evaluated and simplified R formula.
+ * @param label the column name of the R formula label (response variable).
+ * @param terms the simplified terms of the R formula. Interactions terms are represented as Seqs
+ * of column names; non-interaction terms as length 1 Seqs.
+ * @param hasIntercept whether the formula specifies fitting with an intercept.
*/
-private[ml] case class ResolvedRFormula(label: String, terms: Seq[String])
+private[ml] case class ResolvedRFormula(
+ label: String, terms: Seq[Seq[String]], hasIntercept: Boolean)
/**
* R formula terms. See the R formula docs here for more information:
@@ -87,11 +131,17 @@ private[ml] case class ResolvedRFormula(label: String, terms: Seq[String])
*/
private[ml] sealed trait Term
+/** A term that may be part of an interaction, e.g. 'x' in 'x:y' */
+private[ml] sealed trait InteractableTerm extends Term
+
/* R formula reference to all available columns, e.g. "." in a formula */
-private[ml] case object Dot extends Term
+private[ml] case object Dot extends InteractableTerm
/* R formula reference to a column, e.g. "+ Species" in a formula */
-private[ml] case class ColumnRef(value: String) extends Term
+private[ml] case class ColumnRef(value: String) extends InteractableTerm
+
+/* R formula interaction of several columns, e.g. "Sepal_Length:Species" in a formula */
+private[ml] case class ColumnInteraction(terms: Seq[InteractableTerm]) extends Term
/* R formula intercept toggle, e.g. "+ 0" in a formula */
private[ml] case class Intercept(enabled: Boolean) extends Term
@@ -100,25 +150,30 @@ private[ml] case class Intercept(enabled: Boolean) extends Term
private[ml] case class Deletion(term: Term) extends Term
/**
- * Limited implementation of R formula parsing. Currently supports: '~', '+', '-', '.'.
+ * Limited implementation of R formula parsing. Currently supports: '~', '+', '-', '.', ':'.
*/
private[ml] object RFormulaParser extends RegexParsers {
- def intercept: Parser[Intercept] =
+ private val intercept: Parser[Intercept] =
"([01])".r ^^ { case a => Intercept(a == "1") }
- def columnRef: Parser[ColumnRef] =
+ private val columnRef: Parser[ColumnRef] =
"([a-zA-Z]|\\.[a-zA-Z_])[a-zA-Z0-9._]*".r ^^ { case a => ColumnRef(a) }
- def term: Parser[Term] = intercept | columnRef | "\\.".r ^^ { case _ => Dot }
+ private val dot: Parser[InteractableTerm] = "\\.".r ^^ { case _ => Dot }
+
+ private val interaction: Parser[List[InteractableTerm]] = rep1sep(columnRef | dot, ":")
+
+ private val term: Parser[Term] = intercept |
+ interaction ^^ { case terms => ColumnInteraction(terms) } | dot | columnRef
- def terms: Parser[List[Term]] = (term ~ rep("+" ~ term | "-" ~ term)) ^^ {
+ private val terms: Parser[List[Term]] = (term ~ rep("+" ~ term | "-" ~ term)) ^^ {
case op ~ list => list.foldLeft(List(op)) {
case (left, "+" ~ right) => left ++ Seq(right)
case (left, "-" ~ right) => left ++ Seq(Deletion(right))
}
}
- def formula: Parser[ParsedRFormula] =
+ private val formula: Parser[ParsedRFormula] =
(columnRef ~ "~" ~ terms) ^^ { case r ~ "~" ~ t => ParsedRFormula(r, t) }
def parse(value: String): ParsedRFormula = parseAll(formula, value) match {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index 2b1592930e..486274cd75 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -147,9 +147,8 @@ class StringIndexerModel (
}
}
- val outputColName = $(outputCol)
val metadata = NominalAttribute.defaultAttr
- .withName(outputColName).withValues(labels).toMetadata()
+ .withName($(inputCol)).withValues(labels).toMetadata()
// If we are skipping invalid records, filter them out.
val filteredDataset = (getHandleInvalid) match {
case "skip" => {
@@ -161,7 +160,7 @@ class StringIndexerModel (
case _ => dataset
}
filteredDataset.select(col("*"),
- indexer(dataset($(inputCol)).cast(StringType)).as(outputColName, metadata))
+ indexer(dataset($(inputCol)).cast(StringType)).as($(outputCol), metadata))
}
override def transformSchema(schema: StructType): StructType = {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala
index 436e66bab0..53798c659d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala
@@ -25,16 +25,24 @@ class RFormulaParserSuite extends SparkFunSuite {
formula: String,
label: String,
terms: Seq[String],
- schema: StructType = null) {
+ schema: StructType = new StructType) {
val resolved = RFormulaParser.parse(formula).resolve(schema)
assert(resolved.label == label)
- assert(resolved.terms == terms)
+ val simpleTerms = terms.map { t =>
+ if (t.contains(":")) {
+ t.split(":").toSeq
+ } else {
+ Seq(t)
+ }
+ }
+ assert(resolved.terms == simpleTerms)
}
test("parse simple formulas") {
checkParse("y ~ x", "y", Seq("x"))
checkParse("y ~ x + x", "y", Seq("x"))
- checkParse("y ~ ._foo ", "y", Seq("._foo"))
+ checkParse("y~x+z", "y", Seq("x", "z"))
+ checkParse("y ~ ._fo..o ", "y", Seq("._fo..o"))
checkParse("resp ~ A_VAR + B + c123", "resp", Seq("A_VAR", "B", "c123"))
}
@@ -79,4 +87,79 @@ class RFormulaParserSuite extends SparkFunSuite {
assert(!RFormulaParser.parse("a ~ b - 1").hasIntercept)
assert(!RFormulaParser.parse("a ~ b + 1 - 1").hasIntercept)
}
+
+ test("parse interactions") {
+ checkParse("y ~ a:b", "y", Seq("a:b"))
+ checkParse("y ~ ._a:._x", "y", Seq("._a:._x"))
+ checkParse("y ~ foo:bar", "y", Seq("foo:bar"))
+ checkParse("y ~ a : b : c", "y", Seq("a:b:c"))
+ checkParse("y ~ q + a:b:c + b:c + c:d + z", "y", Seq("q", "a:b:c", "b:c", "c:d", "z"))
+ }
+
+ test("parse basic interactions with dot") {
+ val schema = (new StructType)
+ .add("a", "int", true)
+ .add("b", "long", false)
+ .add("c", "string", true)
+ .add("d", "string", true)
+ checkParse("a ~ .:b", "a", Seq("b", "c:b", "d:b"), schema)
+ checkParse("a ~ b:.", "a", Seq("b", "b:c", "b:d"), schema)
+ checkParse("a ~ .:b:.:.:c:d:.", "a", Seq("b:c:d"), schema)
+ }
+
+ // Test data generated in R with terms.formula(y ~ .:., data = iris)
+ test("parse all to all iris interactions") {
+ val schema = (new StructType)
+ .add("Sepal.Length", "double", true)
+ .add("Sepal.Width", "double", true)
+ .add("Petal.Length", "double", true)
+ .add("Petal.Width", "double", true)
+ .add("Species", "string", true)
+ checkParse(
+ "y ~ .:.",
+ "y",
+ Seq(
+ "Sepal.Length",
+ "Sepal.Width",
+ "Petal.Length",
+ "Petal.Width",
+ "Species",
+ "Sepal.Length:Sepal.Width",
+ "Sepal.Length:Petal.Length",
+ "Sepal.Length:Petal.Width",
+ "Sepal.Length:Species",
+ "Sepal.Width:Petal.Length",
+ "Sepal.Width:Petal.Width",
+ "Sepal.Width:Species",
+ "Petal.Length:Petal.Width",
+ "Petal.Length:Species",
+ "Petal.Width:Species"),
+ schema)
+ }
+
+ // Test data generated in R with terms.formula(y ~ .:. - Species:., data = iris)
+ test("parse interaction negation with iris") {
+ val schema = (new StructType)
+ .add("Sepal.Length", "double", true)
+ .add("Sepal.Width", "double", true)
+ .add("Petal.Length", "double", true)
+ .add("Petal.Width", "double", true)
+ .add("Species", "string", true)
+ checkParse("y ~ .:. - .:.", "y", Nil, schema)
+ checkParse(
+ "y ~ .:. - Species:.",
+ "y",
+ Seq(
+ "Sepal.Length",
+ "Sepal.Width",
+ "Petal.Length",
+ "Petal.Width",
+ "Sepal.Length:Sepal.Width",
+ "Sepal.Length:Petal.Length",
+ "Sepal.Length:Petal.Width",
+ "Sepal.Width:Petal.Length",
+ "Sepal.Width:Petal.Width",
+ "Petal.Length:Petal.Width"),
+ schema)
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
index 6aed3243af..b56013008b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
@@ -118,9 +118,81 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext {
val expectedAttrs = new AttributeGroup(
"features",
Array(
- new BinaryAttribute(Some("a__bar"), Some(1)),
- new BinaryAttribute(Some("a__foo"), Some(2)),
+ new BinaryAttribute(Some("a_bar"), Some(1)),
+ new BinaryAttribute(Some("a_foo"), Some(2)),
new NumericAttribute(Some("b"), Some(3))))
assert(attrs === expectedAttrs)
}
+
+ test("numeric interaction") {
+ val formula = new RFormula().setFormula("a ~ b:c:d")
+ val original = sqlContext.createDataFrame(
+ Seq((1, 2, 4, 2), (2, 3, 4, 1))
+ ).toDF("a", "b", "c", "d")
+ val model = formula.fit(original)
+ val result = model.transform(original)
+ val expected = sqlContext.createDataFrame(
+ Seq(
+ (1, 2, 4, 2, Vectors.dense(16.0), 1.0),
+ (2, 3, 4, 1, Vectors.dense(12.0), 2.0))
+ ).toDF("a", "b", "c", "d", "features", "label")
+ assert(result.collect() === expected.collect())
+ val attrs = AttributeGroup.fromStructField(result.schema("features"))
+ val expectedAttrs = new AttributeGroup(
+ "features",
+ Array[Attribute](new NumericAttribute(Some("b:c:d"), Some(1))))
+ assert(attrs === expectedAttrs)
+ }
+
+ test("factor numeric interaction") {
+ val formula = new RFormula().setFormula("id ~ a:b")
+ val original = sqlContext.createDataFrame(
+ Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5), (4, "baz", 5), (4, "baz", 5))
+ ).toDF("id", "a", "b")
+ val model = formula.fit(original)
+ val result = model.transform(original)
+ val expected = sqlContext.createDataFrame(
+ Seq(
+ (1, "foo", 4, Vectors.dense(0.0, 0.0, 4.0), 1.0),
+ (2, "bar", 4, Vectors.dense(0.0, 4.0, 0.0), 2.0),
+ (3, "bar", 5, Vectors.dense(0.0, 5.0, 0.0), 3.0),
+ (4, "baz", 5, Vectors.dense(5.0, 0.0, 0.0), 4.0),
+ (4, "baz", 5, Vectors.dense(5.0, 0.0, 0.0), 4.0),
+ (4, "baz", 5, Vectors.dense(5.0, 0.0, 0.0), 4.0))
+ ).toDF("id", "a", "b", "features", "label")
+ assert(result.collect() === expected.collect())
+ val attrs = AttributeGroup.fromStructField(result.schema("features"))
+ val expectedAttrs = new AttributeGroup(
+ "features",
+ Array[Attribute](
+ new NumericAttribute(Some("a_baz:b"), Some(1)),
+ new NumericAttribute(Some("a_bar:b"), Some(2)),
+ new NumericAttribute(Some("a_foo:b"), Some(3))))
+ assert(attrs === expectedAttrs)
+ }
+
+ test("factor factor interaction") {
+ val formula = new RFormula().setFormula("id ~ a:b")
+ val original = sqlContext.createDataFrame(
+ Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zz"))
+ ).toDF("id", "a", "b")
+ val model = formula.fit(original)
+ val result = model.transform(original)
+ val expected = sqlContext.createDataFrame(
+ Seq(
+ (1, "foo", "zq", Vectors.dense(0.0, 0.0, 1.0, 0.0), 1.0),
+ (2, "bar", "zq", Vectors.dense(1.0, 0.0, 0.0, 0.0), 2.0),
+ (3, "bar", "zz", Vectors.dense(0.0, 1.0, 0.0, 0.0), 3.0))
+ ).toDF("id", "a", "b", "features", "label")
+ assert(result.collect() === expected.collect())
+ val attrs = AttributeGroup.fromStructField(result.schema("features"))
+ val expectedAttrs = new AttributeGroup(
+ "features",
+ Array[Attribute](
+ new NumericAttribute(Some("a_bar:b_zq"), Some(1)),
+ new NumericAttribute(Some("a_bar:b_zz"), Some(2)),
+ new NumericAttribute(Some("a_foo:b_zq"), Some(3)),
+ new NumericAttribute(Some("a_foo:b_zz"), Some(4))))
+ assert(attrs === expectedAttrs)
+ }
}
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index f41d72f877..a4e60f916b 100644
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -1850,7 +1850,7 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol):
Implements the transforms required for fitting a dataset against an
R model formula. Currently we support a limited subset of the R
- operators, including '~', '+', '-', and '.'. Also see the R formula
+ operators, including '~', '.', ':', '+', and '-'. Also see the R formula
docs:
http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html