author     Xiangrui Meng <meng@databricks.com>   2015-11-17 14:04:49 -0800
committer  Xiangrui Meng <meng@databricks.com>   2015-11-17 14:04:49 -0800
commit     3e9e6380236985ec5b51b459f8c61f964a76ff8b (patch)
tree       c01c4d4d3dfd6b58477c0d8dd53d76025bdde965 /mllib/src/test
parent     6eb7008b7f33a36b06d0615b68cc21ed90ad1d8a (diff)
[SPARK-11764][ML] make Param.jsonEncode/jsonDecode support Vector
This PR makes the default read/write work with simple transformers/estimators that have params of type `Param[Vector]`. cc jkbradley

Author: Xiangrui Meng <meng@databricks.com>

Closes #9776 from mengxr/SPARK-11764.
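The diff below only touches the test suite; the corresponding change to `Param.jsonEncode`/`jsonDecode` itself is not shown. A minimal sketch of what that type dispatch could look like, assuming the `Vector.toJson` and `Vectors.fromJson` helpers from `org.apache.spark.mllib.linalg` (an assumption, since the implementation is outside this diff):

```scala
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.json4s._
import org.json4s.jackson.JsonMethods._

// Hypothetical sketch: encode a param value to JSON, dispatching on runtime type.
def jsonEncode(value: Any): String = value match {
  case x: String =>
    compact(render(JString(x)))  // strings render as plain JSON strings
  case v: Vector =>
    v.toJson                     // delegate to mllib's Vector JSON format
  case _ =>
    throw new NotImplementedError(s"jsonEncode is not defined for ${value.getClass}")
}

// Hypothetical sketch: decode, treating a JSON object as a serialized Vector.
def jsonDecode(json: String): Any = parse(json) match {
  case JString(x) => x
  case JObject(_) => Vectors.fromJson(json)  // dense or sparse vector object
  case _ =>
    throw new NotImplementedError(s"jsonDecode is not defined for $json")
}
```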
Diffstat (limited to 'mllib/src/test')
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala | 22
1 file changed, 18 insertions(+), 4 deletions(-)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
index eeb03dba2f..a1878be747 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.ml.param
import org.apache.spark.SparkFunSuite
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
class ParamsSuite extends SparkFunSuite {
@@ -80,7 +81,7 @@ class ParamsSuite extends SparkFunSuite {
}
}
- { // StringParam
+ { // Param[String]
val param = new Param[String](dummy, "name", "doc")
// Currently we do not support null.
for (value <- Seq("", "1", "abc", "quote\"", "newline\n")) {
@@ -89,6 +90,19 @@ class ParamsSuite extends SparkFunSuite {
}
}
+ { // Param[Vector]
+ val param = new Param[Vector](dummy, "name", "doc")
+ val values = Seq(
+ Vectors.dense(Array.empty[Double]),
+ Vectors.dense(0.0, 2.0),
+ Vectors.sparse(0, Array.empty, Array.empty),
+ Vectors.sparse(2, Array(1), Array(2.0)))
+ for (value <- values) {
+ val json = param.jsonEncode(value)
+ assert(param.jsonDecode(json) === value)
+ }
+ }
+
{ // IntArrayParam
val param = new IntArrayParam(dummy, "name", "doc")
val values: Seq[Array[Int]] = Seq(
@@ -138,7 +152,7 @@ class ParamsSuite extends SparkFunSuite {
test("param") {
val solver = new TestParams()
val uid = solver.uid
- import solver.{maxIter, inputCol}
+ import solver.{inputCol, maxIter}
assert(maxIter.name === "maxIter")
assert(maxIter.doc === "maximum number of iterations (>= 0)")
@@ -181,7 +195,7 @@ class ParamsSuite extends SparkFunSuite {
test("param map") {
val solver = new TestParams()
- import solver.{maxIter, inputCol}
+ import solver.{inputCol, maxIter}
val map0 = ParamMap.empty
@@ -220,7 +234,7 @@ class ParamsSuite extends SparkFunSuite {
test("params") {
val solver = new TestParams()
- import solver.{handleInvalid, maxIter, inputCol}
+ import solver.{handleInvalid, inputCol, maxIter}
val params = solver.params
assert(params.length === 3)
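For reference, a hedged illustration of the round trip the new `Param[Vector]` block above exercises; the exact JSON shape (`type` 1 for dense, 0 for sparse) is an assumption based on mllib's vector JSON format, which is not shown in this diff:

```scala
import org.apache.spark.ml.param.Param
import org.apache.spark.mllib.linalg.{Vector, Vectors}

// Hypothetical standalone round trip; "dummy" stands in for a parent uid.
val param = new Param[Vector]("dummy", "name", "doc")
val json = param.jsonEncode(Vectors.dense(0.0, 2.0))
// json is expected to look roughly like: {"type":1,"values":[0.0,2.0]}
assert(param.jsonDecode(json) === Vectors.dense(0.0, 2.0))
```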