-rw-r--r--  core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala  2
-rw-r--r--  python/pyspark/sql/dataframe.py  5
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala (renamed from sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala)  14
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala  16
4 files changed, 25 insertions, 12 deletions
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 19be093903..8464b578ed 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -633,7 +633,7 @@ private[spark] object PythonRDD extends Logging {
   *
   * The thread will terminate after all the data are sent or any exceptions happen.
   */
-  private def serveIterator[T](items: Iterator[T], threadName: String): Int = {
+  def serveIterator[T](items: Iterator[T], threadName: String): Int = {
    val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost"))
    // Close the socket if no connection in 3 seconds
    serverSocket.setSoTimeout(3000)
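
Making serveIterator public lets the SQL layer (the EvaluatePython.takeAndServe helper added below) stream pickled rows to Python over an ephemeral localhost port: the method writes the data from a daemon thread, returns the port number, and closes the socket if nothing connects within 3 seconds. A minimal sketch of the Python-side consumer, assuming the PySpark-internal helpers _load_from_socket, BatchedSerializer and PickleSerializer behave as in this Spark version:

# Sketch of the Python-side consumer for a port returned by PythonRDD.serveIterator.
# _load_from_socket, BatchedSerializer and PickleSerializer are PySpark internals
# (assumed here); the connection must be made within the 3-second accept timeout
# set on the JVM side.
from pyspark.rdd import _load_from_socket
from pyspark.serializers import BatchedSerializer, PickleSerializer

def read_served_rows(port):
    # Connect to localhost:<port> and unpickle the batched stream written by the
    # serving thread on the JVM.
    return list(_load_from_socket(port, BatchedSerializer(PickleSerializer())))
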
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 80f8d8a0eb..b09422aade 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -300,7 +300,10 @@ class DataFrame(object):
        >>> df.take(2)
        [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
        """
-        return self.limit(num).collect()
+        with SCCallSiteSync(self._sc) as css:
+            port = self._sc._jvm.org.apache.spark.sql.execution.EvaluatePython.takeAndServe(
+                self._jdf, num)
+            return list(_load_from_socket(port, BatchedSerializer(PickleSerializer())))

    @ignore_unicode_prefix
    @since(1.3)
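
With this change DataFrame.take(num) no longer builds a limit(num) DataFrame and collects it through the generic collect() path; the rows are taken on the JVM, pickled there, and streamed back over a local socket. A hedged sanity check that the visible behavior is unchanged, using the two-row example DataFrame from the docstring above:

# Old and new code paths should return the same rows (assumes `df` is the example
# DataFrame from the doctest above).
old_style = df.limit(2).collect()   # the pre-change implementation of take()
new_style = df.take(2)              # post-change: EvaluatePython.takeAndServe + local socket
assert old_style == new_style       # both: [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
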
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala
index c35c726bfc..d6aaf424a8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala
@@ -28,7 +28,8 @@ import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.api.python.{PythonRunner, PythonBroadcast, PythonRDD, SerDeUtil}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
@@ -118,6 +119,17 @@ object EvaluatePython {
  def apply(udf: PythonUDF, child: LogicalPlan): EvaluatePython =
    new EvaluatePython(udf, child, AttributeReference("pythonUDF", udf.dataType)())

+  def takeAndServe(df: DataFrame, n: Int): Int = {
+    registerPicklers()
+    // This is an annoying hack - we should refactor the code so executeCollect and executeTake
+    // return InternalRow rather than Row.
+    val converter = CatalystTypeConverters.createToCatalystConverter(df.schema)
+    val iter = new SerDeUtil.AutoBatchedPickler(df.take(n).iterator.map { row =>
+      EvaluatePython.toJava(converter(row).asInstanceOf[InternalRow], df.schema)
+    })
+    PythonRDD.serveIterator(iter, s"serve-DataFrame")
+  }
+
  /**
   * Helper for converting from Catalyst type to java type suitable for Pyrolite.
   */
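
takeAndServe is the JVM entry point behind the new DataFrame.take: it registers the picklers, converts the external Rows returned by df.take(n) back to InternalRow (the "annoying hack" noted in the comment), batch-pickles them, and hands the iterator to PythonRDD.serveIterator. For reference, a sketch of how the Python side reaches it through Py4J, mirroring the dataframe.py hunk above; the underscored attributes are PySpark internals:

# How pyspark calls takeAndServe over Py4J and reads back the served rows
# (mirrors the dataframe.py change; _sc, _jvm and _jdf are PySpark internals).
port = df._sc._jvm.org.apache.spark.sql.execution.EvaluatePython.takeAndServe(df._jdf, 2)
rows = list(_load_from_socket(port, BatchedSerializer(PickleSerializer())))
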
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
index 2fdd798b44..963e6030c1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
@@ -39,22 +39,20 @@ private[sql] class ExamplePointUDT extends UserDefinedType[ExamplePoint] {
  override def pyUDT: String = "pyspark.sql.tests.ExamplePointUDT"

-  override def serialize(obj: Any): Seq[Double] = {
+  override def serialize(obj: Any): GenericArrayData = {
    obj match {
      case p: ExamplePoint =>
-        Seq(p.x, p.y)
+        val output = new Array[Any](2)
+        output(0) = p.x
+        output(1) = p.y
+        new GenericArrayData(output)
    }
  }

  override def deserialize(datum: Any): ExamplePoint = {
    datum match {
-      case values: Seq[_] =>
-        val xy = values.asInstanceOf[Seq[Double]]
-        assert(xy.length == 2)
-        new ExamplePoint(xy(0), xy(1))
-      case values: util.ArrayList[_] =>
-        val xy = values.asInstanceOf[util.ArrayList[Double]].asScala
-        new ExamplePoint(xy(0), xy(1))
+      case values: ArrayData =>
+        new ExamplePoint(values.getDouble(0), values.getDouble(1))
    }
  }
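
The test UDT now serializes into Catalyst's GenericArrayData (and deserializes from ArrayData) instead of a Scala Seq or java.util.ArrayList, matching the internal-row representation that EvaluatePython.toJava sees when takeAndServe pickles rows. A hedged round-trip check from PySpark, assuming the companion test classes ExamplePoint/ExamplePointUDT in pyspark.sql.tests and an existing SQLContext named sqlContext:

# Hedged sketch: a UDT column should survive the new take() path end to end.
# ExamplePoint is the Python-side test class registered against this Scala UDT;
# sqlContext is assumed to be an already-created SQLContext.
from pyspark.sql.tests import ExamplePoint

df = sqlContext.createDataFrame([(ExamplePoint(1.0, 2.0),)], ["point"])
[row] = df.take(1)
assert row.point == ExamplePoint(1.0, 2.0)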