author     Davies Liu <davies@databricks.com>  2016-03-14 22:32:22 -0700
committer  Reynold Xin <rxin@databricks.com>   2016-03-14 22:32:22 -0700
commit     f72743d971a38d3d08984ef4b66e0955945d2f58 (patch)
tree       4cc35ef1a686d97736463808474f2d8843e9a6e9 /sql
parent     9256840cb631cad50852b2b218a1ac71b567084a (diff)
download   spark-f72743d971a38d3d08984ef4b66e0955945d2f58.tar.gz
           spark-f72743d971a38d3d08984ef4b66e0955945d2f58.tar.bz2
           spark-f72743d971a38d3d08984ef4b66e0955945d2f58.zip
[SPARK-13353][SQL] fast serialization for collecting DataFrame/Dataset
## What changes were proposed in this pull request?

When we call DataFrame/Dataset.collect(), the Java serializer (or Kryo serializer) is used to serialize the UnsafeRows on the executors and then deserialize them back into UnsafeRows on the driver. The Java serializer (and the Kryo serializer) are slow on millions of rows, because they try to detect identical rows, but there usually are none. This PR packs the UnsafeRows together and serializes them as a single byte array, which the Java serializer (or Kryo serializer) handles very quickly (there are fewer blocks, and byte arrays are not compared by content). Since the UnsafeRow format is highly compressible, the serialized bytes are also compressed (configurable via spark.io.compression.codec).

## How was this patch tested?

Existing unit tests. A benchmark for collect was added. Before this patch:

```
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
collect:                            Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
-------------------------------------------------------------------------------------------
collect 1 million                        3991 / 4311          0.3        3805.7       1.0X
collect 2 millions                      10083 / 10637         0.1        9616.0       0.4X
collect 4 millions                      29551 / 30072         0.0       28182.3       0.1X
```

After this patch:

```
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
collect:                            Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
-------------------------------------------------------------------------------------------
collect 1 million                         775 / 1170          1.4         738.9       1.0X
collect 2 millions                       1153 / 1758          0.9        1099.3       0.7X
collect 4 millions                       4451 / 5124          0.2        4244.9       0.2X
```

We can see about a 5-7X speedup.

Author: Davies Liu <davies@databricks.com>

Closes #11664 from davies/serialize_row.
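For readers who want to see the framing in isolation: the layout described above is just a length-prefixed byte stream behind a compression codec. Below is a minimal, self-contained sketch of that idea, with GZIP streams standing in for Spark's CompressionCodec and plain byte arrays standing in for serialized UnsafeRows; the names (`FramingSketch`, `pack`, `unpack`) are illustrative only and not part of this patch.

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
import java.util.zip.{GZIPInputStream, GZIPOutputStream}
import scala.collection.mutable.ArrayBuffer

object FramingSketch {
  // Pack payloads as [size][bytes] [size][bytes] ... [-1], all behind one
  // compressed stream (the patch uses Spark's CompressionCodec; GZIP stands in here).
  def pack(rows: Iterator[Array[Byte]]): Array[Byte] = {
    val bos = new ByteArrayOutputStream()
    val out = new DataOutputStream(new GZIPOutputStream(bos))
    rows.foreach { bytes =>
      out.writeInt(bytes.length) // length prefix for the next row
      out.write(bytes)           // raw row bytes
    }
    out.writeInt(-1)             // end-of-stream sentinel
    out.close()                  // finishes the compressed stream
    bos.toByteArray
  }

  // Decode a packed buffer back into individual payloads, mirroring the
  // read loop that the new executeCollect() uses on the driver.
  def unpack(packed: Array[Byte]): Seq[Array[Byte]] = {
    val ins = new DataInputStream(new GZIPInputStream(new ByteArrayInputStream(packed)))
    val rows = ArrayBuffer[Array[Byte]]()
    var size = ins.readInt()
    while (size >= 0) {          // -1 marks the end of the stream
      val buf = new Array[Byte](size)
      ins.readFully(buf)
      rows += buf
      size = ins.readInt()
    }
    rows.toSeq
  }

  def main(args: Array[String]): Unit = {
    val input = Seq("a", "bb", "ccc").map(_.getBytes("UTF-8"))
    val output = unpack(pack(input.iterator)).map(new String(_, "UTF-8"))
    assert(output == Seq("a", "bb", "ccc"))
  }
}
```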
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala                               |  4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala                   | 47
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala                   |  4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala   | 25
4 files changed, 74 insertions, 6 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 1ea7db0388..b5079cf276 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1797,14 +1797,14 @@ class Dataset[T] private[sql](
*/
def collectAsList(): java.util.List[T] = withCallback("collectAsList", toDF()) { _ =>
withNewExecutionId {
- val values = queryExecution.toRdd.map(_.copy()).collect().map(boundTEncoder.fromRow)
+ val values = queryExecution.executedPlan.executeCollect().map(boundTEncoder.fromRow)
java.util.Arrays.asList(values : _*)
}
}
private def collect(needCallback: Boolean): Array[T] = {
def execute(): Array[T] = withNewExecutionId {
- queryExecution.toRdd.map(_.copy()).collect().map(boundTEncoder.fromRow)
+ queryExecution.executedPlan.executeCollect().map(boundTEncoder.fromRow)
}
if (needCallback) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index a92c99e06f..e04683c499 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -17,14 +17,15 @@
package org.apache.spark.sql.execution
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
import java.util.concurrent.atomic.AtomicBoolean
import scala.collection.mutable.ArrayBuffer
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration._
-import org.apache.spark.Logging
-import org.apache.spark.broadcast
+import org.apache.spark.{broadcast, Logging, SparkEnv}
+import org.apache.spark.io.CompressionCodec
import org.apache.spark.rdd.{RDD, RDDOperationScope}
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
@@ -220,7 +221,47 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
* Runs this query returning the result as an array.
*/
def executeCollect(): Array[InternalRow] = {
- execute().map(_.copy()).collect()
+ // Packing the UnsafeRows into byte array for faster serialization.
+ // The byte arrays are in the following format:
+ // [size] [bytes of UnsafeRow] [size] [bytes of UnsafeRow] ... [-1]
+ //
+ // UnsafeRow is highly compressible (at least 8 bytes for any column), the byte array is also
+ // compressed.
+ val byteArrayRdd = execute().mapPartitionsInternal { iter =>
+ val buffer = new Array[Byte](4 << 10) // 4K
+ val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
+ val bos = new ByteArrayOutputStream()
+ val out = new DataOutputStream(codec.compressedOutputStream(bos))
+ while (iter.hasNext) {
+ val row = iter.next().asInstanceOf[UnsafeRow]
+ out.writeInt(row.getSizeInBytes)
+ row.writeToStream(out, buffer)
+ }
+ out.writeInt(-1)
+ out.flush()
+ out.close()
+ Iterator(bos.toByteArray)
+ }
+
+ // Collect the byte arrays back to driver, then decode them as UnsafeRows.
+ val nFields = schema.length
+ val results = ArrayBuffer[InternalRow]()
+
+ byteArrayRdd.collect().foreach { bytes =>
+ val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
+ val bis = new ByteArrayInputStream(bytes)
+ val ins = new DataInputStream(codec.compressedInputStream(bis))
+ var sizeOfNextRow = ins.readInt()
+ while (sizeOfNextRow >= 0) {
+ val bs = new Array[Byte](sizeOfNextRow)
+ ins.readFully(bs)
+ val row = new UnsafeRow(nFields)
+ row.pointTo(bs, sizeOfNextRow)
+ results += row
+ sizeOfNextRow = ins.readInt()
+ }
+ }
+ results.toArray
}
/**
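As a usage note on the hunk above: the packed buffers are compressed with whatever codec spark.io.compression.codec selects, so the choice can be tuned without code changes. A minimal sketch of setting it explicitly, assuming a local Spark setup of this era; the app name and the "lz4" value are example choices, not anything mandated by the patch.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

// The compressed collect() buffers use whatever codec
// spark.io.compression.codec selects; "lz4" here is just an example value.
val conf = new SparkConf()
  .setMaster("local[*]")
  .setAppName("collect-codec-example")
  .set("spark.io.compression.codec", "lz4")
val sc = new SparkContext(conf)
val sqlContext = new SQLContext(sc)

// collect() now routes through SparkPlan.executeCollect(), which packs the
// UnsafeRows into compressed byte arrays before shipping them to the driver.
val rows = sqlContext.range(1 << 20).collect()
println(s"collected ${rows.length} rows")
```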
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala
index 2c4b4f80ff..b1987c6908 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala
@@ -29,7 +29,9 @@ case class FastOperator(output: Seq[Attribute]) extends SparkPlan {
override protected def doExecute(): RDD[InternalRow] = {
val str = Literal("so fast").value
val row = new GenericInternalRow(Array[Any](str))
- sparkContext.parallelize(Seq(row))
+ val unsafeProj = UnsafeProjection.create(schema)
+ val unsafeRow = unsafeProj(row).copy()
+ sparkContext.parallelize(Seq(unsafeRow))
}
override def producedAttributes: AttributeSet = outputSet
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
index 2d3e34d0e1..9f33e4ab62 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
@@ -428,4 +428,29 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
*/
benchmark.run()
}
+
+ ignore("collect") {
+ val N = 1 << 20
+
+ val benchmark = new Benchmark("collect", N)
+ benchmark.addCase("collect 1 million") { iter =>
+ sqlContext.range(N).collect()
+ }
+ benchmark.addCase("collect 2 millions") { iter =>
+ sqlContext.range(N * 2).collect()
+ }
+ benchmark.addCase("collect 4 millions") { iter =>
+ sqlContext.range(N * 4).collect()
+ }
+ benchmark.run()
+
+ /**
+ * Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
+ collect: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+ -------------------------------------------------------------------------------------------
+ collect 1 million 775 / 1170 1.4 738.9 1.0X
+ collect 2 millions 1153 / 1758 0.9 1099.3 0.7X
+ collect 4 millions 4451 / 5124 0.2 4244.9 0.2X
+ */
+ }
}