Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala                              |  4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala                  | 47
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala                 |  4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala | 25
4 files changed, 74 insertions, 6 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 1ea7db0388..b5079cf276 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1797,14 +1797,14 @@ class Dataset[T] private[sql](
   */
  def collectAsList(): java.util.List[T] = withCallback("collectAsList", toDF()) { _ =>
    withNewExecutionId {
-      val values = queryExecution.toRdd.map(_.copy()).collect().map(boundTEncoder.fromRow)
+      val values = queryExecution.executedPlan.executeCollect().map(boundTEncoder.fromRow)
      java.util.Arrays.asList(values : _*)
    }
  }

  private def collect(needCallback: Boolean): Array[T] = {
    def execute(): Array[T] = withNewExecutionId {
-      queryExecution.toRdd.map(_.copy()).collect().map(boundTEncoder.fromRow)
+      queryExecution.executedPlan.executeCollect().map(boundTEncoder.fromRow)
    }

    if (needCallback) {
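For context on what the removed .map(_.copy()) was doing: SQL operators reuse a single row object per partition, so collecting the raw iterator without copying would return many references to one mutated buffer. executeCollect() keeps the copying (and now the compact serialization) inside the physical plan. A minimal, self-contained sketch of the reuse hazard, where the hypothetical MutableRow stands in for Spark's reused UnsafeRow:

    // MutableRow is a hypothetical stand-in for an operator's reused row buffer.
    class MutableRow(var value: Int) {
      def copy(): MutableRow = new MutableRow(value)
    }

    object ReuseHazard {
      def main(args: Array[String]): Unit = {
        val shared = new MutableRow(0)
        // An "operator" that mutates and re-emits the same object:
        def rows = (1 to 3).iterator.map { i => shared.value = i; shared }
        // Collecting without copying: every element aliases the final state.
        assert(rows.toArray.map(_.value).toSeq == Seq(3, 3, 3))
        // Copying each row first (what the old collect() did) preserves the values.
        assert(rows.map(_.copy()).toArray.map(_.value).toSeq == Seq(1, 2, 3))
      }
    }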
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index a92c99e06f..e04683c499 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -17,14 +17,15 @@
package org.apache.spark.sql.execution

+import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
import java.util.concurrent.atomic.AtomicBoolean

import scala.collection.mutable.ArrayBuffer
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration._

-import org.apache.spark.Logging
-import org.apache.spark.broadcast
+import org.apache.spark.{broadcast, Logging, SparkEnv}
+import org.apache.spark.io.CompressionCodec
import org.apache.spark.rdd.{RDD, RDDOperationScope}
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
@@ -220,7 +221,47 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
   * Runs this query returning the result as an array.
   */
  def executeCollect(): Array[InternalRow] = {
-    execute().map(_.copy()).collect()
+    // Pack the UnsafeRows into byte arrays for faster serialization.
+    // The byte arrays are in the following format:
+    // [size] [bytes of UnsafeRow] [size] [bytes of UnsafeRow] ... [-1]
+    //
+    // UnsafeRow is highly compressible (every column takes at least 8 bytes), so the byte
+    // arrays are compressed as well.
+    val byteArrayRdd = execute().mapPartitionsInternal { iter =>
+      val buffer = new Array[Byte](4 << 10)  // 4K
+      val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
+      val bos = new ByteArrayOutputStream()
+      val out = new DataOutputStream(codec.compressedOutputStream(bos))
+      while (iter.hasNext) {
+        val row = iter.next().asInstanceOf[UnsafeRow]
+        out.writeInt(row.getSizeInBytes)
+        row.writeToStream(out, buffer)
+      }
+      out.writeInt(-1)
+      out.flush()
+      out.close()
+      Iterator(bos.toByteArray)
+    }
+
+    // Collect the byte arrays back to the driver, then decode them as UnsafeRows.
+    val nFields = schema.length
+    val results = ArrayBuffer[InternalRow]()
+
+    byteArrayRdd.collect().foreach { bytes =>
+      val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
+      val bis = new ByteArrayInputStream(bytes)
+      val ins = new DataInputStream(codec.compressedInputStream(bis))
+      var sizeOfNextRow = ins.readInt()
+      while (sizeOfNextRow >= 0) {
+        val bs = new Array[Byte](sizeOfNextRow)
+        ins.readFully(bs)
+        val row = new UnsafeRow(nFields)
+        row.pointTo(bs, sizeOfNextRow)
+        results += row
+        sizeOfNextRow = ins.readInt()
+      }
+    }
+    results.toArray
  }

  /**
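The framing used above is easy to exercise outside Spark. Below is a self-contained sketch of the same [size] [bytes] ... [-1] protocol, with java.util.zip's Deflater/Inflater streams standing in for the configurable CompressionCodec and opaque byte arrays standing in for serialized UnsafeRows:

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
    import java.util.zip.{DeflaterOutputStream, InflaterInputStream}
    import scala.collection.mutable.ArrayBuffer

    object FramingSketch {
      // Encode: [size][bytes] ... [-1], all inside one compressed stream.
      def encode(rows: Iterator[Array[Byte]]): Array[Byte] = {
        val bos = new ByteArrayOutputStream()
        val out = new DataOutputStream(new DeflaterOutputStream(bos))
        rows.foreach { bytes =>
          out.writeInt(bytes.length)
          out.write(bytes)
        }
        out.writeInt(-1)  // end-of-stream marker, as in the patch
        out.close()       // finishes and flushes the compressed stream
        bos.toByteArray
      }

      // Decode: read length-prefixed records until the -1 marker.
      def decode(packed: Array[Byte]): Array[Array[Byte]] = {
        val ins = new DataInputStream(new InflaterInputStream(new ByteArrayInputStream(packed)))
        val rows = ArrayBuffer[Array[Byte]]()
        var size = ins.readInt()
        while (size >= 0) {
          val bs = new Array[Byte](size)
          ins.readFully(bs)
          rows += bs
          size = ins.readInt()
        }
        rows.toArray
      }

      def main(args: Array[String]): Unit = {
        val input = Seq("so", "fast", "rows").map(_.getBytes("UTF-8"))
        val roundTripped = decode(encode(input.iterator)).map(new String(_, "UTF-8"))
        assert(roundTripped.toSeq == Seq("so", "fast", "rows"))
      }
    }

The -1 sentinel lets the writer stream records without knowing the count up front, and the length prefixes let the driver slice rows out of one contiguous buffer without any extra row-boundary metadata.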
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala
index 2c4b4f80ff..b1987c6908 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala
@@ -29,7 +29,9 @@ case class FastOperator(output: Seq[Attribute]) extends SparkPlan {
override protected def doExecute(): RDD[InternalRow] = {
val str = Literal("so fast").value
val row = new GenericInternalRow(Array[Any](str))
- sparkContext.parallelize(Seq(row))
+ val unsafeProj = UnsafeProjection.create(schema)
+ val unsafeRow = unsafeProj(row).copy()
+ sparkContext.parallelize(Seq(unsafeRow))
}
override def producedAttributes: AttributeSet = outputSet
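This test change is forced by the new executeCollect: it casts every result row to UnsafeRow, so a physical operator whose doExecute emits GenericInternalRows would now fail with a ClassCastException at collect time. In isolation the conversion looks like the sketch below (it assumes the catalyst internal APIs of this era, so treat the exact imports and signatures as illustrative):

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow}
    import org.apache.spark.sql.types.{StringType, StructField, StructType}
    import org.apache.spark.unsafe.types.UTF8String

    object ToUnsafeSketch {
      def main(args: Array[String]): Unit = {
        val schema = StructType(Seq(StructField("value", StringType)))
        val toUnsafe = UnsafeProjection.create(schema)
        // The projection reuses an internal buffer, hence the copy() before storing.
        val row: UnsafeRow = toUnsafe(InternalRow(UTF8String.fromString("so fast"))).copy()
        // Fixed-width region alone is 16 bytes: an 8-byte null bitset plus one 8-byte slot.
        assert(row.getSizeInBytes >= 16)
      }
    }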
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
index 2d3e34d0e1..9f33e4ab62 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
@@ -428,4 +428,29 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
     */
    benchmark.run()
  }
+
+  ignore("collect") {
+    val N = 1 << 20
+
+    val benchmark = new Benchmark("collect", N)
+    benchmark.addCase("collect 1 million") { iter =>
+      sqlContext.range(N).collect()
+    }
+    benchmark.addCase("collect 2 million") { iter =>
+      sqlContext.range(N * 2).collect()
+    }
+    benchmark.addCase("collect 4 million") { iter =>
+      sqlContext.range(N * 4).collect()
+    }
+    benchmark.run()
+
+    /**
+     * Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
+    collect:                       Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+    ---------------------------------------------------------------------------------------
+    collect 1 million                     775 / 1170          1.4         738.9       1.0X
+    collect 2 million                    1153 / 1758          0.9        1099.3       0.7X
+    collect 4 million                    4451 / 5124          0.2        4244.9       0.2X
+     */
+  }
}
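A note on reading the table: the Benchmark was constructed with N = 1 << 20, so Rate(M/s) and Per Row(ns) are normalized by that single N for every case rather than by each case's own row count. The first line checks out directly: 775 ms over 2^20 rows is about 739 ns per row. The later lines need rescaling: per actually collected row, the 2-million case costs about 550 ns (1153 ms / 2^21) and the 4-million case about 1061 ns (4451 ms / 2^22), so per-row cost dips at 2 million and then grows sharply at 4 million.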