diff options
Diffstat (limited to 'sql')
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala | 37 | ||||
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala | 2 |
2 files changed, 28 insertions, 11 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index d2ceb4a2b0..3bc5dce095 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -377,15 +377,15 @@ class SchemaRDD( def toJavaSchemaRDD: JavaSchemaRDD = new JavaSchemaRDD(sqlContext, logicalPlan) /** - * Converts a JavaRDD to a PythonRDD. It is used by pyspark. + * Helper for converting a Row to a simple Array suitable for pyspark serialization. */ - private[sql] def javaToPython: JavaRDD[Array[Byte]] = { + private def rowToJArray(row: Row, structType: StructType): Array[Any] = { import scala.collection.Map def toJava(obj: Any, dataType: DataType): Any = (obj, dataType) match { case (null, _) => null - case (obj: Row, struct: StructType) => rowToArray(obj, struct) + case (obj: Row, struct: StructType) => rowToJArray(obj, struct) case (seq: Seq[Any], array: ArrayType) => seq.map(x => toJava(x, array.elementType)).asJava @@ -402,23 +402,38 @@ class SchemaRDD( case (other, _) => other } - def rowToArray(row: Row, structType: StructType): Array[Any] = { - val fields = structType.fields.map(field => field.dataType) - row.zip(fields).map { - case (obj, dataType) => toJava(obj, dataType) - }.toArray - } + val fields = structType.fields.map(field => field.dataType) + row.zip(fields).map { + case (obj, dataType) => toJava(obj, dataType) + }.toArray + } + /** + * Converts a JavaRDD to a PythonRDD. It is used by pyspark. + */ + private[sql] def javaToPython: JavaRDD[Array[Byte]] = { val rowSchema = StructType.fromAttributes(this.queryExecution.analyzed.output) this.mapPartitions { iter => val pickle = new Pickler iter.map { row => - rowToArray(row, rowSchema) + rowToJArray(row, rowSchema) }.grouped(100).map(batched => pickle.dumps(batched.toArray)) } } /** + * Serializes the Array[Row] returned by SchemaRDD's optimized collect(), using the same + * format as javaToPython. It is used by pyspark. + */ + private[sql] def collectToPython: JList[Array[Byte]] = { + val rowSchema = StructType.fromAttributes(this.queryExecution.analyzed.output) + val pickle = new Pickler + new java.util.ArrayList(collect().map { row => + rowToJArray(row, rowSchema) + }.grouped(100).map(batched => pickle.dumps(batched.toArray)).toIterable) + } + + /** * Creates SchemaRDD by applying own schema to derived RDD. Typically used to wrap return value * of base RDD functions that do not change schema. * @@ -433,7 +448,7 @@ class SchemaRDD( } // ======================================================================= - // Overriden RDD actions + // Overridden RDD actions // ======================================================================= override def collect(): Array[Row] = queryExecution.executedPlan.executeCollect() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala index 4d799b4038..e7faba0c7f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala @@ -112,6 +112,8 @@ class JavaSchemaRDD( new java.util.ArrayList(arr) } + override def count(): Long = baseSchemaRDD.count + override def take(num: Int): JList[Row] = { import scala.collection.JavaConversions._ val arr: java.util.Collection[Row] = baseSchemaRDD.take(num).toSeq.map(new Row(_)) |