aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main
diff options
context:
space:
mode:
authorXiangrui Meng <meng@databricks.com>2015-11-17 10:17:16 -0800
committerJoseph K. Bradley <joseph@databricks.com>2015-11-17 10:17:16 -0800
commit21fac5434174389e8b83a2f11341fa7c9e360bfd (patch)
tree3e34fff4c100b7bbf273b95f87bdd6a456028937 /mllib/src/main
parentcc567b6634c3142125526f4875795c1b1e862838 (diff)
downloadspark-21fac5434174389e8b83a2f11341fa7c9e360bfd.tar.gz
spark-21fac5434174389e8b83a2f11341fa7c9e360bfd.tar.bz2
spark-21fac5434174389e8b83a2f11341fa7c9e360bfd.zip
[SPARK-11766][MLLIB] add toJson/fromJson to Vector/Vectors
This is to support JSON serialization of Param[Vector] in the pipeline API. It could be used for other purposes too. The schema is the same as `VectorUDT`. jkbradley Author: Xiangrui Meng <meng@databricks.com> Closes #9751 from mengxr/SPARK-11766.
Diffstat (limited to 'mllib/src/main')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala45
1 files changed, 45 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index bd9badc03c..4dcf351df4 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -24,6 +24,9 @@ import scala.annotation.varargs
import scala.collection.JavaConverters._
import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV}
+import org.json4s.DefaultFormats
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods.{compact, render, parse => parseJson}
import org.apache.spark.SparkException
import org.apache.spark.annotation.{AlphaComponent, Since}
@@ -171,6 +174,12 @@ sealed trait Vector extends Serializable {
*/
@Since("1.5.0")
def argmax: Int
+
+ /**
+  * Converts the vector to a JSON string.
+  *
+  * The concrete implementations in this file emit a compact JSON object that carries an
+  * integer "type" field (0 for sparse, 1 for dense) alongside the vector's data, which is
+  * the representation that [[Vectors.fromJson]] reads back.
+  *
+  * @return the JSON representation of this vector
+  */
+ @Since("1.6.0")
+ def toJson: String
}
/**
@@ -339,6 +348,27 @@ object Vectors {
parseNumeric(NumericParser.parse(s))
}
+ /**
+  * Builds a [[Vector]] from its JSON string representation.
+  *
+  * The integer "type" field of the JSON object selects the representation: 0 means sparse
+  * (read from "size", "indices" and "values") and 1 means dense (read from "values").
+  *
+  * @param json JSON string describing a vector
+  * @return the decoded sparse or dense vector
+  * @throws IllegalArgumentException if "type" is an integer other than 0 or 1
+  */
+ @Since("1.6.0")
+ def fromJson(json: String): Vector = {
+   implicit val formats = DefaultFormats
+   val root = parseJson(json)
+   val typeTag = (root \ "type").extract[Int]
+   if (typeTag == 0) {
+     // Sparse encoding: size, indices and values are all stored.
+     val size = (root \ "size").extract[Int]
+     val indices = (root \ "indices").extract[Seq[Int]].toArray
+     val values = (root \ "values").extract[Seq[Double]].toArray
+     sparse(size, indices, values)
+   } else if (typeTag == 1) {
+     // Dense encoding: only the values are stored.
+     val values = (root \ "values").extract[Seq[Double]].toArray
+     dense(values)
+   } else {
+     throw new IllegalArgumentException(s"Cannot parse $json into a vector.")
+   }
+ }
+
private[mllib] def parseNumeric(any: Any): Vector = {
any match {
case values: Array[Double] =>
@@ -650,6 +680,12 @@ class DenseVector @Since("1.0.0") (
maxIdx
}
}
+
+ /**
+  * Renders this dense vector as a compact JSON object with "type" set to 1 and the
+  * entries stored under "values".
+  */
+ @Since("1.6.0")
+ override def toJson: String = {
+   val typeField = "type" -> 1
+   val valuesField = "values" -> values.toSeq
+   compact(render(typeField ~ valuesField))
+ }
}
@Since("1.3.0")
@@ -837,6 +873,15 @@ class SparseVector @Since("1.0.0") (
}.unzip
new SparseVector(selectedIndices.length, sliceInds.toArray, sliceVals.toArray)
}
+
+ /**
+  * Renders this sparse vector as a compact JSON object with "type" set to 0, followed by
+  * the "size", "indices" and "values" fields.
+  */
+ @Since("1.6.0")
+ override def toJson: String = {
+   val header = ("type" -> 0) ~ ("size" -> size)
+   val document = header ~ ("indices" -> indices.toSeq) ~ ("values" -> values.toSeq)
+   compact(render(document))
+ }
}
@Since("1.3.0")