-rw-r--r--  core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala                                                                                            |  2 +-
-rw-r--r--  python/pyspark/sql/dataframe.py                                                                                                                            |  5 ++++-
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala (renamed from sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala) | 14 +++++++++++++-
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala                                                                                    | 16 +++++++---------
4 files changed, 25 insertions(+), 12 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 19be093903..8464b578ed 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -633,7 +633,7 @@ private[spark] object PythonRDD extends Logging {
    *
    * The thread will terminate after all the data are sent or any exceptions happen.
    */
-  private def serveIterator[T](items: Iterator[T], threadName: String): Int = {
+  def serveIterator[T](items: Iterator[T], threadName: String): Int = {
     val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost"))
     // Close the socket if no connection in 3 seconds
     serverSocket.setSoTimeout(3000)
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 80f8d8a0eb..b09422aade 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -300,7 +300,10 @@ class DataFrame(object):
         >>> df.take(2)
         [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
         """
-        return self.limit(num).collect()
+        with SCCallSiteSync(self._sc) as css:
+            port = self._sc._jvm.org.apache.spark.sql.execution.EvaluatePython.takeAndServe(
+                self._jdf, num)
+        return list(_load_from_socket(port, BatchedSerializer(PickleSerializer())))
 
     @ignore_unicode_prefix
     @since(1.3)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala
index c35c726bfc..d6aaf424a8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala
@@ -28,7 +28,8 @@ import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.api.python.{PythonRunner, PythonBroadcast, PythonRDD, SerDeUtil}
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
@@ -118,6 +119,17 @@ object EvaluatePython {
   def apply(udf: PythonUDF, child: LogicalPlan): EvaluatePython =
     new EvaluatePython(udf, child, AttributeReference("pythonUDF", udf.dataType)())
 
+  def takeAndServe(df: DataFrame, n: Int): Int = {
+    registerPicklers()
+    // This is an annoying hack - we should refactor the code so executeCollect and executeTake
+    // returns InternalRow rather than Row.
+    val converter = CatalystTypeConverters.createToCatalystConverter(df.schema)
+    val iter = new SerDeUtil.AutoBatchedPickler(df.take(n).iterator.map { row =>
+      EvaluatePython.toJava(converter(row).asInstanceOf[InternalRow], df.schema)
+    })
+    PythonRDD.serveIterator(iter, s"serve-DataFrame")
+  }
+
   /**
    * Helper for converting from Catalyst type to java type suitable for Pyrolite.
    */
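Note: taken together, the three hunks above form one pipeline. PySpark's DataFrame.take() now calls EvaluatePython.takeAndServe over Py4J; takeAndServe converts the first n external Rows back to InternalRows, pickles them for Pyrolite, and hands the iterator to the newly public PythonRDD.serveIterator, which streams the bytes over an ephemeral localhost socket that _load_from_socket reads on the Python side. Below is a minimal Scala sketch of just the Row-to-InternalRow conversion step, under an assumed two-column schema matching the doctest; it is not part of the commit, and since these converters are private[sql] it would only compile inside the org.apache.spark.sql package tree.

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

// Assumed schema mirroring the doctest above: (age: Int, name: String).
val schema = StructType(Seq(
  StructField("age", IntegerType),
  StructField("name", StringType)))

// For a struct type, createToCatalystConverter yields an Any => Any function
// that maps an external Row to an InternalRow; takeAndServe applies it to
// each row returned by df.take(n) before pickling.
val toCatalyst = CatalystTypeConverters.createToCatalystConverter(schema)
val internal = toCatalyst(Row(2, "Alice")).asInstanceOf[InternalRow]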
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
index 2fdd798b44..963e6030c1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
@@ -39,22 +39,20 @@ private[sql] class ExamplePointUDT extends UserDefinedType[ExamplePoint] {
 
   override def pyUDT: String = "pyspark.sql.tests.ExamplePointUDT"
 
-  override def serialize(obj: Any): Seq[Double] = {
+  override def serialize(obj: Any): GenericArrayData = {
     obj match {
       case p: ExamplePoint =>
-        Seq(p.x, p.y)
+        val output = new Array[Any](2)
+        output(0) = p.x
+        output(1) = p.y
+        new GenericArrayData(output)
     }
   }
 
   override def deserialize(datum: Any): ExamplePoint = {
     datum match {
-      case values: Seq[_] =>
-        val xy = values.asInstanceOf[Seq[Double]]
-        assert(xy.length == 2)
-        new ExamplePoint(xy(0), xy(1))
-      case values: util.ArrayList[_] =>
-        val xy = values.asInstanceOf[util.ArrayList[Double]].asScala
-        new ExamplePoint(xy(0), xy(1))
+      case values: ArrayData =>
+        new ExamplePoint(values.getDouble(0), values.getDouble(1))
     }
   }
 
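Note: for reference, a minimal sketch of how the reworked UDT round-trips a point through ArrayData. The import path for ArrayData/GenericArrayData is assumed from the Spark tree of this era and is not visible in the hunk; like ExamplePointUDT itself (private[sql]), the snippet only compiles inside the org.apache.spark.sql packages.

import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}

val udt = new ExamplePointUDT
// serialize now returns GenericArrayData rather than Seq[Double]...
val arr: GenericArrayData = udt.serialize(new ExamplePoint(1.0, 2.0))
// ...and deserialize matches on any ArrayData, reading doubles by ordinal.
val p: ExamplePoint = udt.deserialize(arr)
assert(p.x == 1.0 && p.y == 2.0)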