aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mllib-local/pom.xml4
-rw-r--r--mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala41
-rw-r--r--mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala17
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/linalg/JsonVectorConverter.scala62
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/linalg/JsonVectorConverterSuite.scala41
5 files changed, 103 insertions, 62 deletions
diff --git a/mllib-local/pom.xml b/mllib-local/pom.xml
index 60b615a07f..68f15dd905 100644
--- a/mllib-local/pom.xml
+++ b/mllib-local/pom.xml
@@ -49,10 +49,6 @@
<scope>test</scope>
</dependency>
<dependency>
- <groupId>org.json4s</groupId>
- <artifactId>json4s-jackson_${scala.binary.version}</artifactId>
- </dependency>
- <dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
<scope>test</scope>
diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala
index 4275a22ae0..c0d112d2c5 100644
--- a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala
+++ b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala
@@ -24,9 +24,6 @@ import scala.annotation.varargs
import scala.collection.JavaConverters._
import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV}
-import org.json4s.DefaultFormats
-import org.json4s.JsonDSL._
-import org.json4s.jackson.JsonMethods.{compact, parse => parseJson, render}
/**
* Represents a numeric vector, whose index type is Int and value type is Double.
@@ -153,11 +150,6 @@ sealed trait Vector extends Serializable {
* Returns -1 if vector has length 0.
*/
def argmax: Int
-
- /**
- * Converts the vector to a JSON string.
- */
- def toJson: String
}
/**
@@ -234,26 +226,6 @@ object Vectors {
}
/**
- * Parses the JSON representation of a vector into a [[Vector]].
- */
- def fromJson(json: String): Vector = {
- implicit val formats = DefaultFormats
- val jValue = parseJson(json)
- (jValue \ "type").extract[Int] match {
- case 0 => // sparse
- val size = (jValue \ "size").extract[Int]
- val indices = (jValue \ "indices").extract[Seq[Int]].toArray
- val values = (jValue \ "values").extract[Seq[Double]].toArray
- sparse(size, indices, values)
- case 1 => // dense
- val values = (jValue \ "values").extract[Seq[Double]].toArray
- dense(values)
- case _ =>
- throw new IllegalArgumentException(s"Cannot parse $json into a vector.")
- }
- }
-
- /**
* Creates a vector instance from a breeze vector.
*/
private[spark] def fromBreeze(breezeVector: BV[Double]): Vector = {
@@ -541,11 +513,6 @@ class DenseVector (val values: Array[Double]) extends Vector {
maxIdx
}
}
-
- override def toJson: String = {
- val jValue = ("type" -> 1) ~ ("values" -> values.toSeq)
- compact(render(jValue))
- }
}
object DenseVector {
@@ -724,14 +691,6 @@ class SparseVector (
}.unzip
new SparseVector(selectedIndices.length, sliceInds.toArray, sliceVals.toArray)
}
-
- override def toJson: String = {
- val jValue = ("type" -> 0) ~
- ("size" -> size) ~
- ("indices" -> indices.toSeq) ~
- ("values" -> values.toSeq)
- compact(render(jValue))
- }
}
object SparseVector {
diff --git a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala
index 504be36413..887814b5e7 100644
--- a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala
+++ b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala
@@ -20,7 +20,6 @@ package org.apache.spark.ml.linalg
import scala.util.Random
import breeze.linalg.{squaredDistance => breezeSquaredDistance, DenseMatrix => BDM}
-import org.json4s.jackson.JsonMethods.{parse => parseJson}
import org.apache.spark.ml.SparkMLFunSuite
import org.apache.spark.ml.util.TestingUtils._
@@ -339,20 +338,4 @@ class VectorsSuite extends SparkMLFunSuite {
assert(v.slice(Array(2, 0)) === new SparseVector(2, Array(0), Array(2.2)))
assert(v.slice(Array(2, 0, 3, 4)) === new SparseVector(4, Array(0, 3), Array(2.2, 4.4)))
}
-
- test("toJson/fromJson") {
- val sv0 = Vectors.sparse(0, Array.empty, Array.empty)
- val sv1 = Vectors.sparse(1, Array.empty, Array.empty)
- val sv2 = Vectors.sparse(2, Array(1), Array(2.0))
- val dv0 = Vectors.dense(Array.empty[Double])
- val dv1 = Vectors.dense(1.0)
- val dv2 = Vectors.dense(0.0, 2.0)
- for (v <- Seq(sv0, sv1, sv2, dv0, dv1, dv2)) {
- val json = v.toJson
- parseJson(json) // `json` should be a valid JSON string
- val u = Vectors.fromJson(json)
- assert(u.getClass === v.getClass, "toJson/fromJson should preserve vector types.")
- assert(u === v, "toJson/fromJson should preserve vector values.")
- }
- }
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/linalg/JsonVectorConverter.scala b/mllib/src/main/scala/org/apache/spark/ml/linalg/JsonVectorConverter.scala
new file mode 100644
index 0000000000..781e69f8d6
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/linalg/JsonVectorConverter.scala
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.linalg
+
+import org.json4s.DefaultFormats
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods.{compact, parse => parseJson, render}
+
+private[ml] object JsonVectorConverter {
+
+ /**
+ * Parses the JSON representation of a vector into a [[Vector]].
+ */
+ def fromJson(json: String): Vector = {
+ implicit val formats = DefaultFormats
+ val jValue = parseJson(json)
+ (jValue \ "type").extract[Int] match {
+ case 0 => // sparse
+ val size = (jValue \ "size").extract[Int]
+ val indices = (jValue \ "indices").extract[Seq[Int]].toArray
+ val values = (jValue \ "values").extract[Seq[Double]].toArray
+ Vectors.sparse(size, indices, values)
+ case 1 => // dense
+ val values = (jValue \ "values").extract[Seq[Double]].toArray
+ Vectors.dense(values)
+ case _ =>
+ throw new IllegalArgumentException(s"Cannot parse $json into a vector.")
+ }
+ }
+
+ /**
+ * Coverts the vector to a JSON string.
+ */
+ def toJson(v: Vector): String = {
+ v match {
+ case SparseVector(size, indices, values) =>
+ val jValue = ("type" -> 0) ~
+ ("size" -> size) ~
+ ("indices" -> indices.toSeq) ~
+ ("values" -> values.toSeq)
+ compact(render(jValue))
+ case DenseVector(values) =>
+ val jValue = ("type" -> 1) ~ ("values" -> values.toSeq)
+ compact(render(jValue))
+ }
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/linalg/JsonVectorConverterSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/linalg/JsonVectorConverterSuite.scala
new file mode 100644
index 0000000000..53d57f0f6e
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/linalg/JsonVectorConverterSuite.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.linalg
+
+import org.json4s.jackson.JsonMethods.parse
+
+import org.apache.spark.SparkFunSuite
+
+class JsonVectorConverterSuite extends SparkFunSuite {
+
+ test("toJson/fromJson") {
+ val sv0 = Vectors.sparse(0, Array.empty, Array.empty)
+ val sv1 = Vectors.sparse(1, Array.empty, Array.empty)
+ val sv2 = Vectors.sparse(2, Array(1), Array(2.0))
+ val dv0 = Vectors.dense(Array.empty[Double])
+ val dv1 = Vectors.dense(1.0)
+ val dv2 = Vectors.dense(0.0, 2.0)
+ for (v <- Seq(sv0, sv1, sv2, dv0, dv1, dv2)) {
+ val json = JsonVectorConverter.toJson(v)
+ parse(json) // `json` should be a valid JSON string
+ val u = JsonVectorConverter.fromJson(json)
+ assert(u.getClass === v.getClass, "toJson/fromJson should preserve vector types.")
+ assert(u === v, "toJson/fromJson should preserve vector values.")
+ }
+ }
+}