author     Liang-Chi Hsieh <simonh@tw.ibm.com>   2016-03-17 23:24:44 -0700
committer  Davies Liu <davies.liu@gmail.com>     2016-03-17 23:24:44 -0700
commit     750ed64cd9db4f81a53caaf1fd6c8a6a0c07887d (patch)
tree       a5ef2b80a85c866ad2b6d2454eff6277220e21ea
parent     10ef4f3e77b3a2a2770a1c869a236203560d4e6d (diff)
[SPARK-13930] [SQL] Apply fast serialization on collect limit operator
## What changes were proposed in this pull request?

JIRA: https://issues.apache.org/jira/browse/SPARK-13930

Fast serialization was recently introduced for collecting DataFrame/Dataset results (#11664). The same technique can be applied to the collect limit operator as well.

## How was this patch tested?

Added a benchmark for collect limit to `BenchmarkWholeStageCodegen`.

Without this patch:

    model name : Westmere E56xx/L56xx/X56xx (Nehalem-C)
    collect limit:                      Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
    -------------------------------------------------------------------------------------------
    collect limit 1 million                  3413 / 3768          0.3       3255.0       1.0X
    collect limit 2 millions                 9728 / 10440         0.1       9277.3       0.4X

With this patch:

    model name : Westmere E56xx/L56xx/X56xx (Nehalem-C)
    collect limit:                      Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
    -------------------------------------------------------------------------------------------
    collect limit 1 million                   833 / 1284          1.3        794.4       1.0X
    collect limit 2 millions                 3348 / 4005          0.3       3193.3       0.2X

Author: Liang-Chi Hsieh <simonh@tw.ibm.com>

Closes #11759 from viirya/execute-take.
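For readers skimming the diff below, here is a minimal standalone sketch of the framing this patch uses. It is not the patch itself: raw `Array[Byte]` payloads stand in for `UnsafeRow`s, and `java.util.zip` GZIP streams stand in for Spark's `CompressionCodec`, but the wire format is the same `[size][row bytes] ... [-1]` layout described in the new `getByteArrayRdd`/`decodeUnsafeRows` methods.

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
import java.util.zip.{GZIPInputStream, GZIPOutputStream}

import scala.collection.mutable.ArrayBuffer

// Standalone sketch of the wire format used by the patch:
// [size][row bytes][size][row bytes]...[-1], with the whole stream compressed.
object RowFramingSketch {

  // Serialize a stream of row payloads into one compressed, length-prefixed blob.
  def encode(rows: Iterator[Array[Byte]]): Array[Byte] = {
    val bos = new ByteArrayOutputStream()
    val out = new DataOutputStream(new GZIPOutputStream(bos))
    rows.foreach { bytes =>
      out.writeInt(bytes.length)   // [size]
      out.write(bytes)             // [bytes of row]
    }
    out.writeInt(-1)               // end-of-stream sentinel
    out.close()                    // flushes and finishes the compressed stream
    bos.toByteArray
  }

  // Read length-prefixed payloads back out until the -1 sentinel is hit.
  def decode(blob: Array[Byte]): Array[Array[Byte]] = {
    val ins = new DataInputStream(new GZIPInputStream(new ByteArrayInputStream(blob)))
    val buffer = ArrayBuffer[Array[Byte]]()
    var sizeOfNext = ins.readInt()
    while (sizeOfNext >= 0) {
      val bytes = new Array[Byte](sizeOfNext)
      ins.readFully(bytes)
      buffer += bytes
      sizeOfNext = ins.readInt()
    }
    buffer.toArray
  }
}
```

In the patch itself, the payload is each UnsafeRow's backing bytes (`row.writeToStream`), and decoding rebuilds rows with `new UnsafeRow(nFields)` plus `pointTo`, as shown in the diff.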
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala                   78
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala  21
2 files changed, 71 insertions(+), 28 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index a392b53412..010ed7f500 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -219,48 +219,62 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
}
/**
- * Runs this query returning the result as an array.
+ * Packing the UnsafeRows into byte array for faster serialization.
+ * The byte arrays are in the following format:
+ * [size] [bytes of UnsafeRow] [size] [bytes of UnsafeRow] ... [-1]
+ *
+ * UnsafeRow is highly compressible (at least 8 bytes for any column), the byte array is also
+ * compressed.
*/
- def executeCollect(): Array[InternalRow] = {
- // Packing the UnsafeRows into byte array for faster serialization.
- // The byte arrays are in the following format:
- // [size] [bytes of UnsafeRow] [size] [bytes of UnsafeRow] ... [-1]
- //
- // UnsafeRow is highly compressible (at least 8 bytes for any column), the byte array is also
- // compressed.
- val byteArrayRdd = execute().mapPartitionsInternal { iter =>
+ private def getByteArrayRdd(n: Int = -1): RDD[Array[Byte]] = {
+ execute().mapPartitionsInternal { iter =>
+ var count = 0
val buffer = new Array[Byte](4 << 10) // 4K
val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
val bos = new ByteArrayOutputStream()
val out = new DataOutputStream(codec.compressedOutputStream(bos))
- while (iter.hasNext) {
+ while (iter.hasNext && (n < 0 || count < n)) {
val row = iter.next().asInstanceOf[UnsafeRow]
out.writeInt(row.getSizeInBytes)
row.writeToStream(out, buffer)
+ count += 1
}
out.writeInt(-1)
out.flush()
out.close()
Iterator(bos.toByteArray)
}
+ }
- // Collect the byte arrays back to driver, then decode them as UnsafeRows.
+ /**
+ * Decode the byte arrays back to UnsafeRows and put them into buffer.
+ */
+ private def decodeUnsafeRows(bytes: Array[Byte], buffer: ArrayBuffer[InternalRow]): Unit = {
val nFields = schema.length
- val results = ArrayBuffer[InternalRow]()
+ val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
+ val bis = new ByteArrayInputStream(bytes)
+ val ins = new DataInputStream(codec.compressedInputStream(bis))
+ var sizeOfNextRow = ins.readInt()
+ while (sizeOfNextRow >= 0) {
+ val bs = new Array[Byte](sizeOfNextRow)
+ ins.readFully(bs)
+ val row = new UnsafeRow(nFields)
+ row.pointTo(bs, sizeOfNextRow)
+ buffer += row
+ sizeOfNextRow = ins.readInt()
+ }
+ }
+
+ /**
+ * Runs this query returning the result as an array.
+ */
+ def executeCollect(): Array[InternalRow] = {
+ val byteArrayRdd = getByteArrayRdd()
+
+ val results = ArrayBuffer[InternalRow]()
byteArrayRdd.collect().foreach { bytes =>
- val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
- val bis = new ByteArrayInputStream(bytes)
- val ins = new DataInputStream(codec.compressedInputStream(bis))
- var sizeOfNextRow = ins.readInt()
- while (sizeOfNextRow >= 0) {
- val bs = new Array[Byte](sizeOfNextRow)
- ins.readFully(bs)
- val row = new UnsafeRow(nFields)
- row.pointTo(bs, sizeOfNextRow)
- results += row
- sizeOfNextRow = ins.readInt()
- }
+ decodeUnsafeRows(bytes, results)
}
results.toArray
}
@@ -283,7 +297,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
return new Array[InternalRow](0)
}
- val childRDD = execute().map(_.copy())
+ val childRDD = getByteArrayRdd(n)
val buf = new ArrayBuffer[InternalRow]
val totalParts = childRDD.partitions.length
@@ -307,13 +321,21 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
val left = n - buf.size
val p = partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt)
val sc = sqlContext.sparkContext
- val res = sc.runJob(childRDD, (it: Iterator[InternalRow]) => it.take(left).toArray, p)
+ val res = sc.runJob(childRDD,
+ (it: Iterator[Array[Byte]]) => if (it.hasNext) it.next() else Array.empty, p)
+
+ res.foreach { r =>
+ decodeUnsafeRows(r.asInstanceOf[Array[Byte]], buf)
+ }
- res.foreach(buf ++= _.take(n - buf.size))
partsScanned += p.size
}
- buf.toArray
+ if (buf.size > n) {
+ buf.take(n).toArray
+ } else {
+ buf.toArray
+ }
}
private[this] def isTesting: Boolean = sys.props.contains("spark.testing")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
index b6051b07c8..d293ff66fb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
@@ -465,4 +465,25 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
collect 4 millions 3193 / 3895 0.3 3044.7 0.1X
*/
}
+
+ ignore("collect limit") {
+ val N = 1 << 20
+
+ val benchmark = new Benchmark("collect limit", N)
+ benchmark.addCase("collect limit 1 million") { iter =>
+ sqlContext.range(N * 4).limit(N).collect()
+ }
+ benchmark.addCase("collect limit 2 millions") { iter =>
+ sqlContext.range(N * 4).limit(N * 2).collect()
+ }
+ benchmark.run()
+
+ /**
+ model name : Westmere E56xx/L56xx/X56xx (Nehalem-C)
+ collect limit: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+ -------------------------------------------------------------------------------------------
+ collect limit 1 million 833 / 1284 1.3 794.4 1.0X
+ collect limit 2 millions 3348 / 4005 0.3 3193.3 0.2X
+ */
+ }
}
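For orientation, a hypothetical driver-side snippet exercising the new path (it assumes a live `sqlContext`, as in the benchmark above): `limit(n).collect()` goes through `SparkPlan.executeTake`, which after this patch ships one compressed byte array per scanned partition and decodes it on the driver.

```scala
// Hypothetical usage mirroring the "collect limit 1 million" benchmark case above.
val N = 1 << 20
val rows = sqlContext.range(N * 4).limit(N).collect()
// executeTake(N) decodes the [size][row bytes]...[-1] blobs on the driver and
// trims the result buffer to exactly N rows.
assert(rows.length == N)
```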