aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/scala
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2016-04-30 08:37:56 -0700
committerXiangrui Meng <meng@databricks.com>2016-04-30 08:37:56 -0700
commit19a6d192d53ce6dffe998ce110adab1f2efcb23e (patch)
tree6900926371373d8bb072d85441df7840918be1f9 /mllib/src/test/scala
parente5fb78baf9a6014b6dd02cf9f528d069732aafca (diff)
downloadspark-19a6d192d53ce6dffe998ce110adab1f2efcb23e.tar.gz
spark-19a6d192d53ce6dffe998ce110adab1f2efcb23e.tar.bz2
spark-19a6d192d53ce6dffe998ce110adab1f2efcb23e.zip
[SPARK-15030][ML][SPARKR] Support formula in spark.kmeans in SparkR
## What changes were proposed in this pull request? * ```RFormula``` supports empty response variable like ```~ x + y```. * Support formula in ```spark.kmeans``` in SparkR. * Fix some outdated docs for SparkR. ## How was this patch tested? Unit tests. Author: Yanbo Liang <ybliang8@gmail.com> Closes #12813 from yanboliang/spark-15030.
Diffstat (limited to 'mllib/src/test/scala')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala19
1 files changed, 19 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
index e1b269b5b6..f8476953d8 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
@@ -23,6 +23,7 @@ import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.Row
class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
test("params") {
@@ -89,6 +90,24 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
assert(resultSchema.toString == model.transform(original).schema.toString)
}
+ test("allow empty label") {
+ val original = sqlContext.createDataFrame(
+ Seq((1, 2.0, 3.0), (4, 5.0, 6.0), (7, 8.0, 9.0))
+ ).toDF("id", "a", "b")
+ val formula = new RFormula().setFormula("~ a + b")
+ val model = formula.fit(original)
+ val result = model.transform(original)
+ val resultSchema = model.transformSchema(original.schema)
+ val expected = sqlContext.createDataFrame(
+ Seq(
+ (1, 2.0, 3.0, Vectors.dense(2.0, 3.0)),
+ (4, 5.0, 6.0, Vectors.dense(5.0, 6.0)),
+ (7, 8.0, 9.0, Vectors.dense(8.0, 9.0)))
+ ).toDF("id", "a", "b", "features")
+ assert(result.schema.toString == resultSchema.toString)
+ assert(result.collect() === expected.collect())
+ }
+
test("encodes string terms") {
val formula = new RFormula().setFormula("id ~ a + b")
val original = sqlContext.createDataFrame(