aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test
diff options
context:
space:
mode:
authorLiang-Chi Hsieh <simonh@tw.ibm.com>2016-06-13 19:59:53 -0700
committerXiangrui Meng <meng@databricks.com>2016-06-13 19:59:53 -0700
commitbaa3e633e18c47b12e79fe3ddc01fc8ec010f096 (patch)
tree83c91014b9d46fc9efc4bf1f5dcef5cee1fe184a /mllib/src/test
parent5827b65e28da168286c771c53a38620d79f5e74f (diff)
downloadspark-baa3e633e18c47b12e79fe3ddc01fc8ec010f096.tar.gz
spark-baa3e633e18c47b12e79fe3ddc01fc8ec010f096.tar.bz2
spark-baa3e633e18c47b12e79fe3ddc01fc8ec010f096.zip
[SPARK-15364][ML][PYSPARK] Implement PySpark picklers for ml.Vector and ml.Matrix under spark.ml.python
## What changes were proposed in this pull request? Now we have PySpark picklers for new and old vector/matrix, individually. However, they are all implemented under `PythonMLlibAPI`. To separate spark.mllib from spark.ml, we should implement the picklers of new vector/matrix under `spark.ml.python` instead. ## How was this patch tested? Existing tests. Author: Liang-Chi Hsieh <simonh@tw.ibm.com> Closes #13219 from viirya/pyspark-pickler-ml.
Diffstat (limited to 'mllib/src/test')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/python/MLSerDeSuite.scala72
1 files changed, 72 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/python/MLSerDeSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/python/MLSerDeSuite.scala
new file mode 100644
index 0000000000..5eaef9aabd
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/python/MLSerDeSuite.scala
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.python
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.linalg.{DenseMatrix, Matrices, SparseMatrix, Vectors}
+
+class MLSerDeSuite extends SparkFunSuite {
+
+ MLSerDe.initialize()
+
+ test("pickle vector") {
+ val vectors = Seq(
+ Vectors.dense(Array.empty[Double]),
+ Vectors.dense(0.0),
+ Vectors.dense(0.0, -2.0),
+ Vectors.sparse(0, Array.empty[Int], Array.empty[Double]),
+ Vectors.sparse(1, Array.empty[Int], Array.empty[Double]),
+ Vectors.sparse(2, Array(1), Array(-2.0)))
+ vectors.foreach { v =>
+ val u = MLSerDe.loads(MLSerDe.dumps(v))
+ assert(u.getClass === v.getClass)
+ assert(u === v)
+ }
+ }
+
+ test("pickle double") {
+ for (x <- List(123.0, -10.0, 0.0, Double.MaxValue, Double.MinValue, Double.NaN)) {
+ val deser = MLSerDe.loads(MLSerDe.dumps(x.asInstanceOf[AnyRef])).asInstanceOf[Double]
+ // We use `equals` here for comparison because we cannot use `==` for NaN
+ assert(x.equals(deser))
+ }
+ }
+
+ test("pickle matrix") {
+ val values = Array[Double](0, 1.2, 3, 4.56, 7, 8)
+ val matrix = Matrices.dense(2, 3, values)
+ val nm = MLSerDe.loads(MLSerDe.dumps(matrix)).asInstanceOf[DenseMatrix]
+ assert(matrix === nm)
+
+ // Test conversion for empty matrix
+ val empty = Array[Double]()
+ val emptyMatrix = Matrices.dense(0, 0, empty)
+ val ne = MLSerDe.loads(MLSerDe.dumps(emptyMatrix)).asInstanceOf[DenseMatrix]
+ assert(emptyMatrix == ne)
+
+ val sm = new SparseMatrix(3, 2, Array(0, 1, 3), Array(1, 0, 2), Array(0.9, 1.2, 3.4))
+ val nsm = MLSerDe.loads(MLSerDe.dumps(sm)).asInstanceOf[SparseMatrix]
+ assert(sm.toArray === nsm.toArray)
+
+ val smt = new SparseMatrix(
+ 3, 3, Array(0, 2, 3, 5), Array(0, 2, 1, 0, 2), Array(0.9, 1.2, 3.4, 5.7, 8.9),
+ isTransposed = true)
+ val nsmt = MLSerDe.loads(MLSerDe.dumps(smt)).asInstanceOf[SparseMatrix]
+ assert(smt.toArray === nsmt.toArray)
+ }
+}