diff options
-rw-r--r-- | core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala | 5 | ||||
-rw-r--r-- | core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala | 5 | ||||
-rw-r--r-- | python/pyspark/tests.py | 19 |
3 files changed, 28 insertions, 1 deletions
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala index 5ba66178e2..c9181a29d4 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala @@ -138,6 +138,11 @@ private[python] class JavaToWritableConverter extends Converter[Any, Writable] { mapWritable.put(convertToWritable(k), convertToWritable(v)) } mapWritable + case array: Array[Any] => { + val arrayWriteable = new ArrayWritable(classOf[Writable]) + arrayWriteable.set(array.map(convertToWritable(_))) + arrayWriteable + } case other => throw new SparkException( s"Data of type ${other.getClass.getName} cannot be used") } diff --git a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala index a4153aaa92..19ca2bb613 100644 --- a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala +++ b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala @@ -153,7 +153,10 @@ private[spark] object SerDeUtil extends Logging { iter.flatMap { row => val obj = unpickle.loads(row) if (batched) { - obj.asInstanceOf[JArrayList[_]].asScala + obj match { + case array: Array[Any] => array.toSeq + case _ => obj.asInstanceOf[JArrayList[_]].asScala + } } else { Seq(obj) } diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index e8e207af46..e694ffcff5 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -714,6 +714,25 @@ class RDDTests(ReusedPySparkTestCase): wr_s21 = rdd.sample(True, 0.4, 21).collect() self.assertNotEqual(set(wr_s11), set(wr_s21)) + def test_multiple_python_java_RDD_conversions(self): + # Regression test for SPARK-5361 + data = [ + (u'1', {u'director': u'David Lean'}), + (u'2', {u'director': u'Andrew Dominik'}) + ] + from pyspark.rdd import RDD + data_rdd = self.sc.parallelize(data) + data_java_rdd = data_rdd._to_java_object_rdd() + data_python_rdd = self.sc._jvm.SerDe.javaToPython(data_java_rdd) + converted_rdd = RDD(data_python_rdd, self.sc) + self.assertEqual(2, converted_rdd.count()) + + # conversion between python and java RDD threw exceptions + data_java_rdd = converted_rdd._to_java_object_rdd() + data_python_rdd = self.sc._jvm.SerDe.javaToPython(data_java_rdd) + converted_rdd = RDD(data_python_rdd, self.sc) + self.assertEqual(2, converted_rdd.count()) + class ProfilerTests(PySparkTestCase): |