about summary refs log tree commit diff
diff options
context:
space:
mode:
authorXiangrui Meng <meng@databricks.com>2015-04-01 18:17:07 -0700
committerXiangrui Meng <meng@databricks.com>2015-04-01 18:31:17 -0700
commit0d1e476683b016fa87a1eac3982c25c701003f96 (patch)
treec23d8c3dbfbc6a0712fc956cbb7e7a4a8942a5ca
parent98f72dfc17853b570d05c20e97c78919682b6df6 (diff)
downloadspark-0d1e476683b016fa87a1eac3982c25c701003f96.tar.gz
spark-0d1e476683b016fa87a1eac3982c25c701003f96.tar.bz2
spark-0d1e476683b016fa87a1eac3982c25c701003f96.zip
[SPARK-6660][MLLIB] pythonToJava doesn't recognize object arrays
davies Author: Xiangrui Meng <meng@databricks.com> Closes #5318 from mengxr/SPARK-6660 and squashes the following commits: 0f66ec2 [Xiangrui Meng] recognize object arrays ad8c42f [Xiangrui Meng] add a test for SPARK-6660 (cherry picked from commit 4815bc2128c7f6d4d21da730b8c72da087233b34) Signed-off-by: Xiangrui Meng <meng@databricks.com> Conflicts: python/pyspark/mllib/tests.py
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala5
-rw-r--r--python/pyspark/mllib/tests.py8
2 files changed, 12 insertions(+), 1 deletion(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 443a8a40ab..74cbd6f5ab 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -1087,7 +1087,10 @@ private[spark] object SerDe extends Serializable {
iter.flatMap { row =>
val obj = unpickle.loads(row)
if (batched) {
- obj.asInstanceOf[JArrayList[_]].asScala
+ obj match {
+ case list: JArrayList[_] => list.asScala
+ case arr: Array[_] => arr
+ }
} else {
Seq(obj)
}
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index 155019638f..113f2163fb 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -36,6 +36,7 @@ if sys.version_info[:2] <= (2, 6):
else:
import unittest
+from pyspark.mllib.common import _to_java_object_rdd
from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector,\
DenseMatrix, Vectors, Matrices
from pyspark.mllib.regression import LabeledPoint
@@ -620,6 +621,13 @@ class ChiSqTestTests(PySparkTestCase):
self.assertEqual(len(chi), num_cols)
self.assertIsNotNone(chi[1000])
+
+class SerDeTest(PySparkTestCase):
+ def test_to_java_object_rdd(self): # SPARK-6660
+ data = RandomRDDs.uniformRDD(self.sc, 10, 5, seed=0L)
+ self.assertEqual(_to_java_object_rdd(data).count(), 10)
+
+
if __name__ == "__main__":
if not _have_scipy:
print "NOTE: Skipping SciPy tests as it does not seem to be installed"