author    Yanbo Liang <ybliang8@gmail.com>    2016-04-30 08:37:56 -0700
committer Xiangrui Meng <meng@databricks.com>    2016-04-30 08:37:56 -0700
commit    19a6d192d53ce6dffe998ce110adab1f2efcb23e (patch)
tree      6900926371373d8bb072d85441df7840918be1f9 /mllib
parent    e5fb78baf9a6014b6dd02cf9f528d069732aafca (diff)
[SPARK-15030][ML][SPARKR] Support formula in spark.kmeans in SparkR
## What changes were proposed in this pull request?
* ```RFormula``` supports an empty response variable, e.g. ```~ x + y```.
* Support formula in ```spark.kmeans``` in SparkR.
* Fix some outdated docs for SparkR.

## How was this patch tested?
Unit tests.

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #12813 from yanboliang/spark-15030.
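As context for the diff below, here is a minimal Scala sketch (not part of the patch) of what the empty-response formula enables; it mirrors the "allow empty label" test added in RFormulaSuite. The DataFrame columns `id`, `a`, `b` are illustrative, and a spark-shell-style `sqlContext` is assumed.

```scala
import org.apache.spark.ml.feature.RFormula

// Toy data matching the new test case.
val df = sqlContext.createDataFrame(
  Seq((1, 2.0, 3.0), (4, 5.0, 6.0), (7, 8.0, 9.0))
).toDF("id", "a", "b")

// A formula with no response variable: only a "features" column is produced,
// which is what an unsupervised estimator like k-means needs.
val model = new RFormula().setFormula("~ a + b").fit(df)
model.transform(df).show()
// Expected: a "features" column of [2.0,3.0], [5.0,6.0], [8.0,9.0] and no label column.
```

On the SparkR side (outside this mllib-only diff), the formula string is passed through to `KMeansWrapper.fit`, whose new signature appears in the KMeansWrapper.scala hunk below.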
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala | 4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala | 9
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala | 32
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala | 4
-rw-r--r--  mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java | 2
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala | 19
6 files changed, 50 insertions(+), 20 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
index 3ac6c77669..5219680be2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -214,7 +214,7 @@ class RFormulaModel private[feature](
override def transformSchema(schema: StructType): StructType = {
checkCanTransform(schema)
val withFeatures = pipelineModel.transformSchema(schema)
- if (hasLabelCol(withFeatures)) {
+ if (resolvedFormula.label.isEmpty || hasLabelCol(withFeatures)) {
withFeatures
} else if (schema.exists(_.name == resolvedFormula.label)) {
val nullable = schema(resolvedFormula.label).dataType match {
@@ -236,7 +236,7 @@ class RFormulaModel private[feature](
private def transformLabel(dataset: Dataset[_]): DataFrame = {
val labelName = resolvedFormula.label
- if (hasLabelCol(dataset.schema)) {
+ if (labelName.isEmpty || hasLabelCol(dataset.schema)) {
dataset.toDF
} else if (dataset.schema.exists(_.name == labelName)) {
dataset.schema(labelName).dataType match {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala
index 4079b387e1..cf52710ab8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala
@@ -63,6 +63,9 @@ private[ml] case class ParsedRFormula(label: ColumnRef, terms: Seq[Term]) {
ResolvedRFormula(label.value, includedTerms.distinct, hasIntercept)
}
+ /** Whether this formula specifies fitting with response variable. */
+ def hasLabel: Boolean = label.value.nonEmpty
+
/** Whether this formula specifies fitting with an intercept term. */
def hasIntercept: Boolean = {
var intercept = true
@@ -159,6 +162,10 @@ private[ml] object RFormulaParser extends RegexParsers {
private val columnRef: Parser[ColumnRef] =
"([a-zA-Z]|\\.[a-zA-Z_])[a-zA-Z0-9._]*".r ^^ { case a => ColumnRef(a) }
+ private val empty: Parser[ColumnRef] = "" ^^ { case a => ColumnRef("") }
+
+ private val label: Parser[ColumnRef] = columnRef | empty
+
private val dot: Parser[InteractableTerm] = "\\.".r ^^ { case _ => Dot }
private val interaction: Parser[List[InteractableTerm]] = rep1sep(columnRef | dot, ":")
@@ -174,7 +181,7 @@ private[ml] object RFormulaParser extends RegexParsers {
}
private val formula: Parser[ParsedRFormula] =
- (columnRef ~ "~" ~ terms) ^^ { case r ~ "~" ~ t => ParsedRFormula(r, t) }
+ (label ~ "~" ~ terms) ^^ { case r ~ "~" ~ t => ParsedRFormula(r, t) }
def parse(value: String): ParsedRFormula = parseAll(formula, value) match {
case Success(result, _) => result
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala
index f67760d3ca..4d4c303fc8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala
@@ -25,7 +25,7 @@ import org.json4s.jackson.JsonMethods._
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.clustering.{KMeans, KMeansModel}
-import org.apache.spark.ml.feature.VectorAssembler
+import org.apache.spark.ml.feature.RFormula
import org.apache.spark.ml.util._
import org.apache.spark.sql.{DataFrame, Dataset}
@@ -65,28 +65,32 @@ private[r] object KMeansWrapper extends MLReadable[KMeansWrapper] {
def fit(
data: DataFrame,
- k: Double,
- maxIter: Double,
- initMode: String,
- columns: Array[String]): KMeansWrapper = {
+ formula: String,
+ k: Int,
+ maxIter: Int,
+ initMode: String): KMeansWrapper = {
+
+ val rFormulaModel = new RFormula()
+ .setFormula(formula)
+ .setFeaturesCol("features")
+ .fit(data)
- val assembler = new VectorAssembler()
- .setInputCols(columns)
- .setOutputCol("features")
+ // get feature names from output schema
+ val schema = rFormulaModel.transform(data).schema
+ val featureAttrs = AttributeGroup.fromStructField(schema(rFormulaModel.getFeaturesCol))
+ .attributes.get
+ val features = featureAttrs.map(_.name.get)
val kMeans = new KMeans()
- .setK(k.toInt)
- .setMaxIter(maxIter.toInt)
+ .setK(k)
+ .setMaxIter(maxIter)
.setInitMode(initMode)
val pipeline = new Pipeline()
- .setStages(Array(assembler, kMeans))
+ .setStages(Array(rFormulaModel, kMeans))
.fit(data)
val kMeansModel: KMeansModel = pipeline.stages(1).asInstanceOf[KMeansModel]
- val attrs = AttributeGroup.fromStructField(
- kMeansModel.summary.predictions.schema(kMeansModel.getFeaturesCol))
- val features: Array[String] = attrs.attributes.get.map(_.name.get)
val size: Array[Long] = kMeansModel.summary.clusterSizes
new KMeansWrapper(pipeline, features, size)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
index 9c0757941e..568c160ee5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
@@ -25,7 +25,7 @@ import org.apache.spark.SparkException
import org.apache.spark.ml.util.MLReader
/**
- * This is the Scala stub of SparkR ml.load. It will dispatch the call to corresponding
+ * This is the Scala stub of SparkR read.ml. It will dispatch the call to corresponding
* model wrapper loading function according the class name extracted from rMetadata of the path.
*/
private[r] object RWrappers extends MLReader[Object] {
@@ -45,7 +45,7 @@ private[r] object RWrappers extends MLReader[Object] {
case "org.apache.spark.ml.r.KMeansWrapper" =>
KMeansWrapper.load(path)
case _ =>
- throw new SparkException(s"SparkR ml.load does not support load $className")
+ throw new SparkException(s"SparkR read.ml does not support load $className")
}
}
}
diff --git a/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java
index 66b2ceacb0..5f1d5987e8 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java
@@ -72,7 +72,7 @@ public class JavaStatisticsSuite implements Serializable {
Double corr1 = Statistics.corr(x, y);
Double corr2 = Statistics.corr(x, y, "pearson");
// Check default method
- assertEquals(corr1, corr2);
+ assertEquals(corr1, corr2, 1e-5);
}
@Test
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
index e1b269b5b6..f8476953d8 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
@@ -23,6 +23,7 @@ import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.Row
class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
test("params") {
@@ -89,6 +90,24 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
assert(resultSchema.toString == model.transform(original).schema.toString)
}
+ test("allow empty label") {
+ val original = sqlContext.createDataFrame(
+ Seq((1, 2.0, 3.0), (4, 5.0, 6.0), (7, 8.0, 9.0))
+ ).toDF("id", "a", "b")
+ val formula = new RFormula().setFormula("~ a + b")
+ val model = formula.fit(original)
+ val result = model.transform(original)
+ val resultSchema = model.transformSchema(original.schema)
+ val expected = sqlContext.createDataFrame(
+ Seq(
+ (1, 2.0, 3.0, Vectors.dense(2.0, 3.0)),
+ (4, 5.0, 6.0, Vectors.dense(5.0, 6.0)),
+ (7, 8.0, 9.0, Vectors.dense(8.0, 9.0)))
+ ).toDF("id", "a", "b", "features")
+ assert(result.schema.toString == resultSchema.toString)
+ assert(result.collect() === expected.collect())
+ }
+
test("encodes string terms") {
val formula = new RFormula().setFormula("id ~ a + b")
val original = sqlContext.createDataFrame(