Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala              37
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala  2
2 files changed, 28 insertions, 11 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
index d2ceb4a2b0..3bc5dce095 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
@@ -377,15 +377,15 @@ class SchemaRDD(
def toJavaSchemaRDD: JavaSchemaRDD = new JavaSchemaRDD(sqlContext, logicalPlan)
/**
- * Converts a JavaRDD to a PythonRDD. It is used by pyspark.
+ * Helper for converting a Row to a simple Array suitable for pyspark serialization.
*/
- private[sql] def javaToPython: JavaRDD[Array[Byte]] = {
+ private def rowToJArray(row: Row, structType: StructType): Array[Any] = {
import scala.collection.Map
def toJava(obj: Any, dataType: DataType): Any = (obj, dataType) match {
case (null, _) => null
- case (obj: Row, struct: StructType) => rowToArray(obj, struct)
+ case (obj: Row, struct: StructType) => rowToJArray(obj, struct)
case (seq: Seq[Any], array: ArrayType) =>
seq.map(x => toJava(x, array.elementType)).asJava
@@ -402,23 +402,38 @@ class SchemaRDD(
case (other, _) => other
}
- def rowToArray(row: Row, structType: StructType): Array[Any] = {
- val fields = structType.fields.map(field => field.dataType)
- row.zip(fields).map {
- case (obj, dataType) => toJava(obj, dataType)
- }.toArray
- }
+ val fields = structType.fields.map(field => field.dataType)
+ row.zip(fields).map {
+ case (obj, dataType) => toJava(obj, dataType)
+ }.toArray
+ }
+ /**
+ * Converts a JavaRDD to a PythonRDD. It is used by pyspark.
+ */
+ private[sql] def javaToPython: JavaRDD[Array[Byte]] = {
val rowSchema = StructType.fromAttributes(this.queryExecution.analyzed.output)
this.mapPartitions { iter =>
val pickle = new Pickler
iter.map { row =>
- rowToArray(row, rowSchema)
+ rowToJArray(row, rowSchema)
}.grouped(100).map(batched => pickle.dumps(batched.toArray))
}
}
/**
+ * Serializes the Array[Row] returned by SchemaRDD's optimized collect(), using the same
+ * format as javaToPython. It is used by pyspark.
+ */
+ private[sql] def collectToPython: JList[Array[Byte]] = {
+ val rowSchema = StructType.fromAttributes(this.queryExecution.analyzed.output)
+ val pickle = new Pickler
+ new java.util.ArrayList(collect().map { row =>
+ rowToJArray(row, rowSchema)
+ }.grouped(100).map(batched => pickle.dumps(batched.toArray)).toIterable)
+ }
+
+ /**
* Creates SchemaRDD by applying own schema to derived RDD. Typically used to wrap return value
* of base RDD functions that do not change schema.
*
@@ -433,7 +448,7 @@ class SchemaRDD(
}
// =======================================================================
- // Overriden RDD actions
+ // Overridden RDD actions
// =======================================================================
override def collect(): Array[Row] = queryExecution.executedPlan.executeCollect()
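Both javaToPython above and the new collectToPython convert each Row to a plain Array with rowToJArray, group the arrays into batches of 100, and pickle one batch per byte array. A minimal sketch of that batching pattern, written against the Pyrolite Pickler these methods already use (the helper name pickleInBatches is illustrative only, not part of this change):

import net.razorvine.pickle.Pickler

// Pickle rows in fixed-size batches: one byte[] per batch of (already
// converted) Array[Any] rows, mirroring the grouped(100) pattern above.
def pickleInBatches(rows: Iterator[Array[Any]], batchSize: Int = 100): Iterator[Array[Byte]] = {
  val pickle = new Pickler
  rows.grouped(batchSize).map(batch => pickle.dumps(batch.toArray))
}

Pickling whole batches rather than individual rows amortizes serialization overhead per batch, at the cost of holding up to batchSize converted rows in memory at a time.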
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala
index 4d799b4038..e7faba0c7f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala
@@ -112,6 +112,8 @@ class JavaSchemaRDD(
new java.util.ArrayList(arr)
}
+ override def count(): Long = baseSchemaRDD.count
+
override def take(num: Int): JList[Row] = {
import scala.collection.JavaConversions._
val arr: java.util.Collection[Row] = baseSchemaRDD.take(num).toSeq.map(new Row(_))
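With the added count() override, the Java API delegates to baseSchemaRDD.count instead of the generic JavaRDDLike implementation, so it picks up whatever counting strategy SchemaRDD itself provides. A hedged usage sketch (it assumes a running SparkContext named sc and a table called "people" that has already been registered; neither comes from this diff):

import org.apache.spark.sql.SQLContext

val sqlContext = new SQLContext(sc)
val schemaRDD = sqlContext.sql("SELECT * FROM people")   // SchemaRDD
val javaSchemaRDD = schemaRDD.toJavaSchemaRDD            // Java-friendly wrapper
val rows: Long = javaSchemaRDD.count()                   // now routes to baseSchemaRDD.count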