author     Aaron Staple <aaron.staple@gmail.com>        2014-09-16 11:45:35 -0700
committer  Michael Armbrust <michael@databricks.com>    2014-09-16 11:45:35 -0700
commit     8e7ae477ba40a064d27cf149aa211ff6108fe239 (patch)
tree       3f30546913a7e1ca882e65fcbac721ed0ad36258 /sql
parent     30f288ae34a67307aa45b7aecbd0d02a0a14fe69 (diff)
[SPARK-2314][SQL] Override collect and take in python library, and count in java library, with optimized versions.
SchemaRDD overrides RDD functions, including collect, count, and take, with optimized versions making use of the query optimizer. The java and python interface classes wrapping SchemaRDD need to ensure the optimized versions are called as well. This patch overrides relevant calls in the python and java interfaces with optimized versions.

Adds a new Row serialization pathway between python and java, based on JList[Array[Byte]] versus the existing RDD[Array[Byte]]. I wasn’t overjoyed about doing this, but I noticed that some QueryPlans implement optimizations in executeCollect(), which outputs an Array[Row] rather than the typical RDD[Row] that can be shipped to python using the existing serialization code. To me it made sense to ship the Array[Row] over to python directly instead of converting it back to an RDD[Row] just for the purpose of sending the Rows to python using the existing serialization code.

Author: Aaron Staple <aaron.staple@gmail.com>

Closes #1592 from staple/SPARK-2314 and squashes the following commits:

89ff550 [Aaron Staple] Merge with master.
6bb7b6c [Aaron Staple] Fix typo.
b56d0ac [Aaron Staple] [SPARK-2314][SQL] Override count in JavaSchemaRDD, forwarding to SchemaRDD's count.
0fc9d40 [Aaron Staple] Fix comment typos.
f03cdfa [Aaron Staple] [SPARK-2314][SQL] Override collect and take in sql.py, forwarding to SchemaRDD's collect.
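A rough, self-contained sketch of the batching-and-pickling pattern behind the new collectToPython pathway (the object and method names, CollectToPythonSketch and pickleBatches, are invented for illustration; Pyrolite's Pickler and the batch size of 100 come from the patch below):

import java.util.{ArrayList => JArrayList, List => JList}
import scala.collection.JavaConverters._

import net.razorvine.pickle.Pickler

object CollectToPythonSketch {
  // Rows are assumed to already be converted to plain Array[Any] (the job of
  // rowToJArray in the patch). Each batch of rows is pickled into one byte
  // array; the resulting java.util.List can then be handed to Python in a single call.
  def pickleBatches(rows: Seq[Array[Any]], batchSize: Int = 100): JList[Array[Byte]] = {
    val pickler = new Pickler
    val pickled = rows.grouped(batchSize).map(batch => pickler.dumps(batch.toArray)).toSeq
    new JArrayList[Array[Byte]](pickled.asJava)
  }
}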
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala               37
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala   2
2 files changed, 28 insertions, 11 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
index d2ceb4a2b0..3bc5dce095 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
@@ -377,15 +377,15 @@ class SchemaRDD(
def toJavaSchemaRDD: JavaSchemaRDD = new JavaSchemaRDD(sqlContext, logicalPlan)
/**
- * Converts a JavaRDD to a PythonRDD. It is used by pyspark.
+ * Helper for converting a Row to a simple Array suitable for pyspark serialization.
*/
- private[sql] def javaToPython: JavaRDD[Array[Byte]] = {
+ private def rowToJArray(row: Row, structType: StructType): Array[Any] = {
import scala.collection.Map
def toJava(obj: Any, dataType: DataType): Any = (obj, dataType) match {
case (null, _) => null
- case (obj: Row, struct: StructType) => rowToArray(obj, struct)
+ case (obj: Row, struct: StructType) => rowToJArray(obj, struct)
case (seq: Seq[Any], array: ArrayType) =>
seq.map(x => toJava(x, array.elementType)).asJava
@@ -402,23 +402,38 @@ class SchemaRDD(
case (other, _) => other
}
- def rowToArray(row: Row, structType: StructType): Array[Any] = {
- val fields = structType.fields.map(field => field.dataType)
- row.zip(fields).map {
- case (obj, dataType) => toJava(obj, dataType)
- }.toArray
- }
+ val fields = structType.fields.map(field => field.dataType)
+ row.zip(fields).map {
+ case (obj, dataType) => toJava(obj, dataType)
+ }.toArray
+ }
+ /**
+ * Converts a JavaRDD to a PythonRDD. It is used by pyspark.
+ */
+ private[sql] def javaToPython: JavaRDD[Array[Byte]] = {
val rowSchema = StructType.fromAttributes(this.queryExecution.analyzed.output)
this.mapPartitions { iter =>
val pickle = new Pickler
iter.map { row =>
- rowToArray(row, rowSchema)
+ rowToJArray(row, rowSchema)
}.grouped(100).map(batched => pickle.dumps(batched.toArray))
}
}
/**
+ * Serializes the Array[Row] returned by SchemaRDD's optimized collect(), using the same
+ * format as javaToPython. It is used by pyspark.
+ */
+ private[sql] def collectToPython: JList[Array[Byte]] = {
+ val rowSchema = StructType.fromAttributes(this.queryExecution.analyzed.output)
+ val pickle = new Pickler
+ new java.util.ArrayList(collect().map { row =>
+ rowToJArray(row, rowSchema)
+ }.grouped(100).map(batched => pickle.dumps(batched.toArray)).toIterable)
+ }
+
+ /**
* Creates SchemaRDD by applying own schema to derived RDD. Typically used to wrap return value
* of base RDD functions that do not change schema.
*
@@ -433,7 +448,7 @@ class SchemaRDD(
}
// =======================================================================
- // Overriden RDD actions
+ // Overridden RDD actions
// =======================================================================
override def collect(): Array[Row] = queryExecution.executedPlan.executeCollect()
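As background for the collect() override above: per the commit message, some physical plans implement executeCollect() more cheaply than running the whole plan and then collecting. A hedged, self-contained illustration of that pattern, using invented names (PlanNode, TakeFirst) rather than Spark's actual SparkPlan API:

object CollectOverrideSketch {
  type Row = Seq[Any]

  // Stand-in for a physical plan node: executeAll() plays the role of full
  // distributed execution, executeCollect() the role of the driver-side collect.
  trait PlanNode {
    def executeAll(): Iterator[Row]
    def executeCollect(): Array[Row] = executeAll().toArray   // default: materialize everything
  }

  // A LIMIT-like node can answer collect() by pulling only the first n rows,
  // skipping the work of producing and shipping the rest.
  case class TakeFirst(n: Int, child: PlanNode) extends PlanNode {
    def executeAll(): Iterator[Row] = child.executeAll().take(n)
    override def executeCollect(): Array[Row] = child.executeAll().take(n).toArray
  }
}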
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala
index 4d799b4038..e7faba0c7f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala
@@ -112,6 +112,8 @@ class JavaSchemaRDD(
new java.util.ArrayList(arr)
}
+ override def count(): Long = baseSchemaRDD.count
+
override def take(num: Int): JList[Row] = {
import scala.collection.JavaConversions._
val arr: java.util.Collection[Row] = baseSchemaRDD.take(num).toSeq.map(new Row(_))