aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorXiangrui Meng <meng@databricks.com>2015-10-13 13:24:10 -0700
committerXiangrui Meng <meng@databricks.com>2015-10-13 13:24:10 -0700
commit2b574f52d7bf51b1fe2a73086a3735b633e9083f (patch)
treec713972503cf3c51c6c077aa09621bb3c36690fd /mllib
parentc75f058b72d492d6de898957b3058f242d70dd8a (diff)
downloadspark-2b574f52d7bf51b1fe2a73086a3735b633e9083f.tar.gz
spark-2b574f52d7bf51b1fe2a73086a3735b633e9083f.tar.bz2
spark-2b574f52d7bf51b1fe2a73086a3735b633e9083f.zip
[SPARK-7402] [ML] JSON SerDe for standard param types
This PR implements the JSON SerDe for the following param types: `Boolean`, `Int`, `Long`, `Float`, `Double`, `String`, `Array[Int]`, `Array[Double]`, and `Array[String]`. The implementations for `Float`, `Double`, and `Array[Double]` are specialized to handle `NaN` and `Inf`s. This will be used in pipeline persistence. jkbradley Author: Xiangrui Meng <meng@databricks.com> Closes #9090 from mengxr/SPARK-7402.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/param/params.scala169
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala114
2 files changed, 283 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index ec98b05e13..8361406f87 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -24,6 +24,9 @@ import scala.annotation.varargs
import scala.collection.mutable
import scala.collection.JavaConverters._
+import org.json4s._
+import org.json4s.jackson.JsonMethods._
+
import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.ml.util.Identifiable
@@ -80,6 +83,30 @@ class Param[T](val parent: String, val name: String, val doc: String, val isVali
/** Creates a param pair with the given value (for Scala). */
def ->(value: T): ParamPair[T] = ParamPair(this, value)
+  /**
+   * Encodes a param value into JSON, which can be decoded by [[jsonDecode()]].
+   *
+   * This default implementation only handles `String` values, which are written as a JSON
+   * string. Param types with any other value type must override this method.
+   *
+   * @param value the param value to encode; `null` is not supported
+   * @return the compact JSON text for `value`
+   * @throws NotImplementedError if `value` is not a non-null `String`
+   */
+  def jsonEncode(value: T): String = {
+    value match {
+      case x: String =>
+        compact(render(JString(x)))
+      case _ =>
+        // `null` never matches the typed pattern above, so guard the class lookup: calling
+        // getClass on it would raise a NullPointerException instead of the intended error.
+        val valueType = if (value == null) "null" else value.getClass.getName
+        throw new NotImplementedError(
+          "The default jsonEncode only supports string. " +
+            s"${this.getClass.getName} must override jsonEncode for $valueType.")
+    }
+  }
+
+  /**
+   * Decodes a param value from JSON.
+   *
+   * This default implementation accepts only a JSON string and returns its contents cast to `T`
+   * (the cast is unchecked). Any other JSON value raises a `NotImplementedError`.
+   */
+  def jsonDecode(json: String): T = parse(json) match {
+    case JString(text) =>
+      text.asInstanceOf[T]
+    case _ =>
+      val hint = s"${this.getClass.getName} must override jsonDecode to support its value type."
+      throw new NotImplementedError("The default jsonDecode only supports string. " + hint)
+  }
+
override final def toString: String = s"${parent}__$name"
override final def hashCode: Int = toString.##
@@ -198,6 +225,46 @@ class DoubleParam(parent: String, name: String, doc: String, isValid: Double =>
/** Creates a param pair with the given value (for Java). */
override def w(value: Double): ParamPair[Double] = super.w(value)
+
+  /** Renders `value` as compact JSON, going through `DoubleParam.jValueEncode`. */
+  override def jsonEncode(value: Double): String = {
+    val encoded = DoubleParam.jValueEncode(value)
+    compact(render(encoded))
+  }
+
+  /** Parses `json` and hands the resulting JValue to `DoubleParam.jValueDecode`. */
+  override def jsonDecode(json: String): Double = {
+    val parsed = parse(json)
+    DoubleParam.jValueDecode(parsed)
+  }
+}
+
+private[param] object DoubleParam {
+  /**
+   * Encodes a param value into JValue.
+   *
+   * NaN and the two infinities are written as the strings "NaN", "-Inf" and "Inf"; every other
+   * value becomes a JDouble.
+   */
+  def jValueEncode(value: Double): JValue = {
+    value match {
+      case _ if value.isNaN =>
+        JString("NaN")
+      case Double.NegativeInfinity =>
+        JString("-Inf")
+      case Double.PositiveInfinity =>
+        JString("Inf")
+      case _ =>
+        JDouble(value)
+    }
+  }
+
+  /**
+   * Decodes a param value from JValue.
+   *
+   * Accepts the three special strings produced by [[jValueEncode]] and any JSON number
+   * (JDouble, JInt or JDecimal).
+   *
+   * @throws IllegalArgumentException if `jValue` is none of the above
+   */
+  def jValueDecode(jValue: JValue): Double = {
+    jValue match {
+      case JString("NaN") =>
+        Double.NaN
+      case JString("-Inf") =>
+        Double.NegativeInfinity
+      case JString("Inf") =>
+        Double.PositiveInfinity
+      case JDouble(x) =>
+        x
+      case JInt(x) =>
+        // A number written without fraction or exponent (e.g. `1`) is parsed as JInt, which is
+        // still a valid double value.
+        x.toDouble
+      case JDecimal(x) =>
+        x.toDouble
+      case _ =>
+        throw new IllegalArgumentException(s"Cannot decode $jValue to Double.")
+    }
+  }
}
/**
@@ -218,6 +285,15 @@ class IntParam(parent: String, name: String, doc: String, isValid: Int => Boolea
/** Creates a param pair with the given value (for Java). */
override def w(value: Int): ParamPair[Int] = super.w(value)
+
+  /** Wraps `value` in a JInt and renders it as compact JSON. */
+  override def jsonEncode(value: Int): String = {
+    val ast: JValue = JInt(value)
+    compact(render(ast))
+  }
+
+  /** Parses `json` and extracts an Int from it using json4s `DefaultFormats`. */
+  override def jsonDecode(json: String): Int = {
+    implicit val formats: Formats = DefaultFormats
+    val parsed = parse(json)
+    parsed.extract[Int]
+  }
}
/**
@@ -238,6 +314,47 @@ class FloatParam(parent: String, name: String, doc: String, isValid: Float => Bo
/** Creates a param pair with the given value (for Java). */
override def w(value: Float): ParamPair[Float] = super.w(value)
+
+  /** Renders `value` as compact JSON, going through `FloatParam.jValueEncode`. */
+  override def jsonEncode(value: Float): String = {
+    val encoded = FloatParam.jValueEncode(value)
+    compact(render(encoded))
+  }
+
+  /** Parses `json` and hands the resulting JValue to `FloatParam.jValueDecode`. */
+  override def jsonDecode(json: String): Float = {
+    val parsed = parse(json)
+    FloatParam.jValueDecode(parsed)
+  }
+}
+
+private[param] object FloatParam {
+
+  /**
+   * Encodes a param value into JValue.
+   *
+   * NaN and the two infinities are written as the strings "NaN", "-Inf" and "Inf"; every other
+   * value is widened to Double and stored as a JDouble.
+   */
+  def jValueEncode(value: Float): JValue = {
+    value match {
+      case _ if value.isNaN =>
+        JString("NaN")
+      case Float.NegativeInfinity =>
+        JString("-Inf")
+      case Float.PositiveInfinity =>
+        JString("Inf")
+      case _ =>
+        JDouble(value)
+    }
+  }
+
+  /**
+   * Decodes a param value from JValue.
+   *
+   * Accepts the three special strings produced by [[jValueEncode]] and any JSON number
+   * (JDouble, JInt or JDecimal), narrowing numbers to Float.
+   *
+   * @throws IllegalArgumentException if `jValue` is none of the above
+   */
+  def jValueDecode(jValue: JValue): Float = {
+    jValue match {
+      case JString("NaN") =>
+        Float.NaN
+      case JString("-Inf") =>
+        Float.NegativeInfinity
+      case JString("Inf") =>
+        Float.PositiveInfinity
+      case JDouble(x) =>
+        x.toFloat
+      case JInt(x) =>
+        // A number written without fraction or exponent (e.g. `1`) is parsed as JInt, which is
+        // still a valid float value.
+        x.toFloat
+      case JDecimal(x) =>
+        x.toFloat
+      case _ =>
+        throw new IllegalArgumentException(s"Cannot decode $jValue to Float.")
+    }
+  }
}
/**
@@ -258,6 +375,15 @@ class LongParam(parent: String, name: String, doc: String, isValid: Long => Bool
/** Creates a param pair with the given value (for Java). */
override def w(value: Long): ParamPair[Long] = super.w(value)
+
+  /** Wraps `value` in a JInt and renders it as compact JSON. */
+  override def jsonEncode(value: Long): String = {
+    val ast: JValue = JInt(value)
+    compact(render(ast))
+  }
+
+  /** Parses `json` and extracts a Long from it using json4s `DefaultFormats`. */
+  override def jsonDecode(json: String): Long = {
+    implicit val formats: Formats = DefaultFormats
+    val parsed = parse(json)
+    parsed.extract[Long]
+  }
}
/**
@@ -272,6 +398,15 @@ class BooleanParam(parent: String, name: String, doc: String) // No need for isV
/** Creates a param pair with the given value (for Java). */
override def w(value: Boolean): ParamPair[Boolean] = super.w(value)
+
+  /** Wraps `value` in a JBool and renders it as compact JSON. */
+  override def jsonEncode(value: Boolean): String = {
+    val ast: JValue = JBool(value)
+    compact(render(ast))
+  }
+
+  /** Parses `json` and extracts a Boolean from it using json4s `DefaultFormats`. */
+  override def jsonDecode(json: String): Boolean = {
+    implicit val formats: Formats = DefaultFormats
+    val parsed = parse(json)
+    parsed.extract[Boolean]
+  }
}
/**
@@ -287,6 +422,16 @@ class StringArrayParam(parent: Params, name: String, doc: String, isValid: Array
/** Creates a param pair with a [[java.util.List]] of values (for Java and Python). */
def w(value: java.util.List[String]): ParamPair[Array[String]] = w(value.asScala.toArray)
+
+  /** Converts `value` to a JSON array through the json4s DSL and renders it compactly. */
+  override def jsonEncode(value: Array[String]): String = {
+    import org.json4s.JsonDSL._
+    val ast: JValue = value.toSeq
+    compact(render(ast))
+  }
+
+  /** Parses `json`, extracts a `Seq[String]` using `DefaultFormats`, and copies it to an array. */
+  override def jsonDecode(json: String): Array[String] = {
+    implicit val formats: Formats = DefaultFormats
+    val strings = parse(json).extract[Seq[String]]
+    strings.toArray
+  }
}
/**
@@ -303,6 +448,20 @@ class DoubleArrayParam(parent: Params, name: String, doc: String, isValid: Array
/** Creates a param pair with a [[java.util.List]] of values (for Java and Python). */
def w(value: java.util.List[java.lang.Double]): ParamPair[Array[Double]] =
w(value.asScala.map(_.asInstanceOf[Double]).toArray)
+
+  /**
+   * Encodes each element with `DoubleParam.jValueEncode`, then renders the resulting sequence
+   * as a compact JSON array through the json4s DSL.
+   */
+  override def jsonEncode(value: Array[Double]): String = {
+    import org.json4s.JsonDSL._
+    val encoded = value.toSeq.map(DoubleParam.jValueEncode)
+    compact(render(encoded))
+  }
+
+  /**
+   * Parses `json` as a JSON array and decodes each element with `DoubleParam.jValueDecode`.
+   * Any other top-level JSON value raises an `IllegalArgumentException`.
+   */
+  override def jsonDecode(json: String): Array[Double] = parse(json) match {
+    case JArray(elements) =>
+      elements.map(element => DoubleParam.jValueDecode(element)).toArray
+    case _ =>
+      throw new IllegalArgumentException(s"Cannot decode $json to Array[Double].")
+  }
}
/**
@@ -319,6 +478,16 @@ class IntArrayParam(parent: Params, name: String, doc: String, isValid: Array[In
/** Creates a param pair with a [[java.util.List]] of values (for Java and Python). */
def w(value: java.util.List[java.lang.Integer]): ParamPair[Array[Int]] =
w(value.asScala.map(_.asInstanceOf[Int]).toArray)
+
+  /** Converts `value` to a JSON array through the json4s DSL and renders it compactly. */
+  override def jsonEncode(value: Array[Int]): String = {
+    import org.json4s.JsonDSL._
+    val ast: JValue = value.toSeq
+    compact(render(ast))
+  }
+
+  /** Parses `json`, extracts a `Seq[Int]` using `DefaultFormats`, and copies it to an array. */
+  override def jsonDecode(json: String): Array[Int] = {
+    implicit val formats: Formats = DefaultFormats
+    val ints = parse(json).extract[Seq[Int]]
+    ints.toArray
+  }
}
/**
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
index a2ea279f5d..eeb03dba2f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
@@ -21,6 +21,120 @@ import org.apache.spark.SparkFunSuite
class ParamsSuite extends SparkFunSuite {
+  test("json encode/decode") {
+    // Bare-bones Params instance that serves only as the parent of the params under test.
+    val dummy = new Params {
+      override def copy(extra: ParamMap): Params = defaultCopy(extra)
+
+      override val uid: String = "dummy"
+    }
+
+    // Pushes `value` through jsonEncode and back through jsonDecode on the same param.
+    def roundTrip[T](param: Param[T], value: T): T = param.jsonDecode(param.jsonEncode(value))
+
+    { // BooleanParam
+      val param = new BooleanParam(dummy, "name", "doc")
+      Seq(true, false).foreach { value =>
+        assert(roundTrip(param, value) === value)
+      }
+    }
+
+    { // IntParam
+      val param = new IntParam(dummy, "name", "doc")
+      Seq(Int.MinValue, -1, 0, 1, Int.MaxValue).foreach { value =>
+        assert(roundTrip(param, value) === value)
+      }
+    }
+
+    { // LongParam
+      val param = new LongParam(dummy, "name", "doc")
+      Seq(Long.MinValue, -1L, 0L, 1L, Long.MaxValue).foreach { value =>
+        assert(roundTrip(param, value) === value)
+      }
+    }
+
+    { // FloatParam
+      val param = new FloatParam(dummy, "name", "doc")
+      val values = Seq(Float.NaN, Float.NegativeInfinity, Float.MinValue, -1.0f, -0.5f, 0.0f,
+        Float.MinPositiveValue, 0.5f, 1.0f, Float.MaxValue, Float.PositiveInfinity)
+      values.foreach { value =>
+        val decoded = roundTrip(param, value)
+        // NaN never compares equal to itself, so it gets its own check.
+        if (value.isNaN) {
+          assert(decoded.isNaN)
+        } else {
+          assert(decoded === value)
+        }
+      }
+    }
+
+    { // DoubleParam
+      val param = new DoubleParam(dummy, "name", "doc")
+      val values = Seq(Double.NaN, Double.NegativeInfinity, Double.MinValue, -1.0, -0.5, 0.0,
+        Double.MinPositiveValue, 0.5, 1.0, Double.MaxValue, Double.PositiveInfinity)
+      values.foreach { value =>
+        val decoded = roundTrip(param, value)
+        // NaN never compares equal to itself, so it gets its own check.
+        if (value.isNaN) {
+          assert(decoded.isNaN)
+        } else {
+          assert(decoded === value)
+        }
+      }
+    }
+
+    { // StringParam
+      val param = new Param[String](dummy, "name", "doc")
+      // Currently we do not support null.
+      Seq("", "1", "abc", "quote\"", "newline\n").foreach { value =>
+        assert(roundTrip(param, value) === value)
+      }
+    }
+
+    { // IntArrayParam
+      val param = new IntArrayParam(dummy, "name", "doc")
+      val values: Seq[Array[Int]] = Seq(
+        Array(),
+        Array(1),
+        Array(Int.MinValue, 0, Int.MaxValue))
+      values.foreach { value =>
+        assert(roundTrip(param, value) === value)
+      }
+    }
+
+    { // DoubleArrayParam
+      val param = new DoubleArrayParam(dummy, "name", "doc")
+      val values: Seq[Array[Double]] = Seq(
+        Array(),
+        Array(1.0),
+        Array(Double.NaN, Double.NegativeInfinity, Double.MinValue, -1.0, 0.0,
+          Double.MinPositiveValue, 1.0, Double.MaxValue, Double.PositiveInfinity))
+      values.foreach { value =>
+        val decoded = roundTrip(param, value)
+        assert(decoded.length === value.length)
+        // Compare element by element so that NaN entries can be checked with isNaN.
+        decoded.zip(value).foreach { case (actual, expected) =>
+          if (expected.isNaN) {
+            assert(actual.isNaN)
+          } else {
+            assert(actual === expected)
+          }
+        }
+      }
+    }
+
+    { // StringArrayParam
+      val param = new StringArrayParam(dummy, "name", "doc")
+      val values: Seq[Array[String]] = Seq(
+        Array(),
+        Array(""),
+        Array("", "1", "abc", "quote\"", "newline\n"))
+      values.foreach { value =>
+        assert(roundTrip(param, value) === value)
+      }
+    }
+  }
+
test("param") {
val solver = new TestParams()
val uid = solver.uid