diff options
author | Xiangrui Meng <meng@databricks.com> | 2015-11-17 14:04:49 -0800 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-11-17 14:04:49 -0800 |
commit | 3e9e6380236985ec5b51b459f8c61f964a76ff8b (patch) | |
tree | c01c4d4d3dfd6b58477c0d8dd53d76025bdde965 /mllib/src/main | |
parent | 6eb7008b7f33a36b06d0615b68cc21ed90ad1d8a (diff) | |
download | spark-3e9e6380236985ec5b51b459f8c61f964a76ff8b.tar.gz spark-3e9e6380236985ec5b51b459f8c61f964a76ff8b.tar.bz2 spark-3e9e6380236985ec5b51b459f8c61f964a76ff8b.zip |
[SPARK-11764][ML] make Param.jsonEncode/jsonDecode support Vector
This PR makes the default read/write work with simple transformers/estimators that have params of type `Param[Vector]`. jkbradley
Author: Xiangrui Meng <meng@databricks.com>
Closes #9776 from mengxr/SPARK-11764.
Diffstat (limited to 'mllib/src/main')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/param/params.scala | 12 |
1 files changed, 10 insertions, 2 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index c932570918..d182b0a988 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -29,6 +29,7 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.ml.util.Identifiable +import org.apache.spark.mllib.linalg.{Vector, Vectors} /** * :: DeveloperApi :: @@ -88,9 +89,11 @@ class Param[T](val parent: String, val name: String, val doc: String, val isVali value match { case x: String => compact(render(JString(x))) + case v: Vector => + v.toJson case _ => throw new NotImplementedError( - "The default jsonEncode only supports string. " + + "The default jsonEncode only supports string and vector. " + s"${this.getClass.getName} must override jsonEncode for ${value.getClass.getName}.") } } @@ -100,9 +103,14 @@ class Param[T](val parent: String, val name: String, val doc: String, val isVali parse(json) match { case JString(x) => x.asInstanceOf[T] + case JObject(v) => + val keys = v.map(_._1) + assert(keys.contains("type") && keys.contains("values"), + s"Expect a JSON serialized vector but cannot find fields 'type' and 'values' in $json.") + Vectors.fromJson(json).asInstanceOf[T] case _ => throw new NotImplementedError( - "The default jsonDecode only supports string. " + + "The default jsonDecode only supports string and vector. " + s"${this.getClass.getName} must override jsonDecode to support its value type.") } } |