aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/scala/org/apache
diff options
context:
space:
mode:
authorXiangrui Meng <meng@databricks.com>2015-10-13 13:24:10 -0700
committerXiangrui Meng <meng@databricks.com>2015-10-13 13:24:10 -0700
commit2b574f52d7bf51b1fe2a73086a3735b633e9083f (patch)
treec713972503cf3c51c6c077aa09621bb3c36690fd /mllib/src/test/scala/org/apache
parentc75f058b72d492d6de898957b3058f242d70dd8a (diff)
downloadspark-2b574f52d7bf51b1fe2a73086a3735b633e9083f.tar.gz
spark-2b574f52d7bf51b1fe2a73086a3735b633e9083f.tar.bz2
spark-2b574f52d7bf51b1fe2a73086a3735b633e9083f.zip
[SPARK-7402] [ML] JSON SerDe for standard param types
This PR implements the JSON SerDe for the following param types: `Boolean`, `Int`, `Long`, `Float`, `Double`, `String`, `Array[Int]`, `Array[Double]`, and `Array[String]`. The implementation of `Float`, `Double`, and `Array[Double]` are specialized to handle `NaN` and `Inf`s. This will be used in pipeline persistence. jkbradley Author: Xiangrui Meng <meng@databricks.com> Closes #9090 from mengxr/SPARK-7402.
Diffstat (limited to 'mllib/src/test/scala/org/apache')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala114
1 files changed, 114 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
index a2ea279f5d..eeb03dba2f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
@@ -21,6 +21,120 @@ import org.apache.spark.SparkFunSuite
class ParamsSuite extends SparkFunSuite {
+ test("json encode/decode") {
+ val dummy = new Params {
+ override def copy(extra: ParamMap): Params = defaultCopy(extra)
+
+ override val uid: String = "dummy"
+ }
+
+ { // BooleanParam
+ val param = new BooleanParam(dummy, "name", "doc")
+ for (value <- Seq(true, false)) {
+ val json = param.jsonEncode(value)
+ assert(param.jsonDecode(json) === value)
+ }
+ }
+
+ { // IntParam
+ val param = new IntParam(dummy, "name", "doc")
+ for (value <- Seq(Int.MinValue, -1, 0, 1, Int.MaxValue)) {
+ val json = param.jsonEncode(value)
+ assert(param.jsonDecode(json) === value)
+ }
+ }
+
+ { // LongParam
+ val param = new LongParam(dummy, "name", "doc")
+ for (value <- Seq(Long.MinValue, -1L, 0L, 1L, Long.MaxValue)) {
+ val json = param.jsonEncode(value)
+ assert(param.jsonDecode(json) === value)
+ }
+ }
+
+ { // FloatParam
+ val param = new FloatParam(dummy, "name", "doc")
+ for (value <- Seq(Float.NaN, Float.NegativeInfinity, Float.MinValue, -1.0f, -0.5f, 0.0f,
+ Float.MinPositiveValue, 0.5f, 1.0f, Float.MaxValue, Float.PositiveInfinity)) {
+ val json = param.jsonEncode(value)
+ val decoded = param.jsonDecode(json)
+ if (value.isNaN) {
+ assert(decoded.isNaN)
+ } else {
+ assert(decoded === value)
+ }
+ }
+ }
+
+ { // DoubleParam
+ val param = new DoubleParam(dummy, "name", "doc")
+ for (value <- Seq(Double.NaN, Double.NegativeInfinity, Double.MinValue, -1.0, -0.5, 0.0,
+ Double.MinPositiveValue, 0.5, 1.0, Double.MaxValue, Double.PositiveInfinity)) {
+ val json = param.jsonEncode(value)
+ val decoded = param.jsonDecode(json)
+ if (value.isNaN) {
+ assert(decoded.isNaN)
+ } else {
+ assert(decoded === value)
+ }
+ }
+ }
+
+ { // StringParam
+ val param = new Param[String](dummy, "name", "doc")
+ // Currently we do not support null.
+ for (value <- Seq("", "1", "abc", "quote\"", "newline\n")) {
+ val json = param.jsonEncode(value)
+ assert(param.jsonDecode(json) === value)
+ }
+ }
+
+ { // IntArrayParam
+ val param = new IntArrayParam(dummy, "name", "doc")
+ val values: Seq[Array[Int]] = Seq(
+ Array(),
+ Array(1),
+ Array(Int.MinValue, 0, Int.MaxValue))
+ for (value <- values) {
+ val json = param.jsonEncode(value)
+ assert(param.jsonDecode(json) === value)
+ }
+ }
+
+ { // DoubleArrayParam
+ val param = new DoubleArrayParam(dummy, "name", "doc")
+ val values: Seq[Array[Double]] = Seq(
+ Array(),
+ Array(1.0),
+ Array(Double.NaN, Double.NegativeInfinity, Double.MinValue, -1.0, 0.0,
+ Double.MinPositiveValue, 1.0, Double.MaxValue, Double.PositiveInfinity))
+ for (value <- values) {
+ val json = param.jsonEncode(value)
+ val decoded = param.jsonDecode(json)
+ assert(decoded.length === value.length)
+ decoded.zip(value).foreach { case (actual, expected) =>
+ if (expected.isNaN) {
+ assert(actual.isNaN)
+ } else {
+ assert(actual === expected)
+ }
+ }
+ }
+ }
+
+ { // StringArrayParam
+ val param = new StringArrayParam(dummy, "name", "doc")
+ val values: Seq[Array[String]] = Seq(
+ Array(),
+ Array(""),
+ Array("", "1", "abc", "quote\"", "newline\n"))
+ for (value <- values) {
+ val json = param.jsonEncode(value)
+ assert(param.jsonDecode(json) === value)
+ }
+ }
+ }
+
test("param") {
val solver = new TestParams()
val uid = solver.uid