diff options
author | Liang-Chi Hsieh <simonh@tw.ibm.com> | 2016-03-17 23:24:44 -0700 |
---|---|---|
committer | Davies Liu <davies.liu@gmail.com> | 2016-03-17 23:24:44 -0700 |
commit | 750ed64cd9db4f81a53caaf1fd6c8a6a0c07887d (patch) | |
tree | a5ef2b80a85c866ad2b6d2454eff6277220e21ea | |
parent | 10ef4f3e77b3a2a2770a1c869a236203560d4e6d (diff) | |
download | spark-750ed64cd9db4f81a53caaf1fd6c8a6a0c07887d.tar.gz spark-750ed64cd9db4f81a53caaf1fd6c8a6a0c07887d.tar.bz2 spark-750ed64cd9db4f81a53caaf1fd6c8a6a0c07887d.zip |
[SPARK-13930] [SQL] Apply fast serialization on collect limit operator
## What changes were proposed in this pull request?
JIRA: https://issues.apache.org/jira/browse/SPARK-13930
Fast serialization was recently introduced for collecting DataFrame/Dataset results (#11664). The same technique can be applied to the collect limit operator as well.
## How was this patch tested?
Add a benchmark for collect limit to `BenchmarkWholeStageCodegen`.
Without this patch:
model name : Westmere E56xx/L56xx/X56xx (Nehalem-C)
collect limit: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------
collect limit 1 million 3413 / 3768 0.3 3255.0 1.0X
collect limit 2 millions 9728 / 10440 0.1 9277.3 0.4X
With this patch:
model name : Westmere E56xx/L56xx/X56xx (Nehalem-C)
collect limit: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------
collect limit 1 million 833 / 1284 1.3 794.4 1.0X
collect limit 2 millions 3348 / 4005 0.3 3193.3 0.2X
Author: Liang-Chi Hsieh <simonh@tw.ibm.com>
Closes #11759 from viirya/execute-take.
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala | 78 | ||||
-rw-r--r-- | sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala | 21 |
2 files changed, 71 insertions, 28 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index a392b53412..010ed7f500 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -219,48 +219,62 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ } /** - * Runs this query returning the result as an array. + * Packing the UnsafeRows into byte array for faster serialization. + * The byte arrays are in the following format: + * [size] [bytes of UnsafeRow] [size] [bytes of UnsafeRow] ... [-1] + * + * UnsafeRow is highly compressible (at least 8 bytes for any column), the byte array is also + * compressed. */ - def executeCollect(): Array[InternalRow] = { - // Packing the UnsafeRows into byte array for faster serialization. - // The byte arrays are in the following format: - // [size] [bytes of UnsafeRow] [size] [bytes of UnsafeRow] ... [-1] - // - // UnsafeRow is highly compressible (at least 8 bytes for any column), the byte array is also - // compressed. - val byteArrayRdd = execute().mapPartitionsInternal { iter => + private def getByteArrayRdd(n: Int = -1): RDD[Array[Byte]] = { + execute().mapPartitionsInternal { iter => + var count = 0 val buffer = new Array[Byte](4 << 10) // 4K val codec = CompressionCodec.createCodec(SparkEnv.get.conf) val bos = new ByteArrayOutputStream() val out = new DataOutputStream(codec.compressedOutputStream(bos)) - while (iter.hasNext) { + while (iter.hasNext && (n < 0 || count < n)) { val row = iter.next().asInstanceOf[UnsafeRow] out.writeInt(row.getSizeInBytes) row.writeToStream(out, buffer) + count += 1 } out.writeInt(-1) out.flush() out.close() Iterator(bos.toByteArray) } + } - // Collect the byte arrays back to driver, then decode them as UnsafeRows. + /** + * Decode the byte arrays back to UnsafeRows and put them into buffer. 
+ */ + private def decodeUnsafeRows(bytes: Array[Byte], buffer: ArrayBuffer[InternalRow]): Unit = { val nFields = schema.length - val results = ArrayBuffer[InternalRow]() + val codec = CompressionCodec.createCodec(SparkEnv.get.conf) + val bis = new ByteArrayInputStream(bytes) + val ins = new DataInputStream(codec.compressedInputStream(bis)) + var sizeOfNextRow = ins.readInt() + while (sizeOfNextRow >= 0) { + val bs = new Array[Byte](sizeOfNextRow) + ins.readFully(bs) + val row = new UnsafeRow(nFields) + row.pointTo(bs, sizeOfNextRow) + buffer += row + sizeOfNextRow = ins.readInt() + } + } + + /** + * Runs this query returning the result as an array. + */ + def executeCollect(): Array[InternalRow] = { + val byteArrayRdd = getByteArrayRdd() + + val results = ArrayBuffer[InternalRow]() byteArrayRdd.collect().foreach { bytes => - val codec = CompressionCodec.createCodec(SparkEnv.get.conf) - val bis = new ByteArrayInputStream(bytes) - val ins = new DataInputStream(codec.compressedInputStream(bis)) - var sizeOfNextRow = ins.readInt() - while (sizeOfNextRow >= 0) { - val bs = new Array[Byte](sizeOfNextRow) - ins.readFully(bs) - val row = new UnsafeRow(nFields) - row.pointTo(bs, sizeOfNextRow) - results += row - sizeOfNextRow = ins.readInt() - } + decodeUnsafeRows(bytes, results) } results.toArray } @@ -283,7 +297,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ return new Array[InternalRow](0) } - val childRDD = execute().map(_.copy()) + val childRDD = getByteArrayRdd(n) val buf = new ArrayBuffer[InternalRow] val totalParts = childRDD.partitions.length @@ -307,13 +321,21 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ val left = n - buf.size val p = partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt) val sc = sqlContext.sparkContext - val res = sc.runJob(childRDD, (it: Iterator[InternalRow]) => it.take(left).toArray, p) + val res = sc.runJob(childRDD, + (it: 
Iterator[Array[Byte]]) => if (it.hasNext) it.next() else Array.empty, p) + + res.foreach { r => + decodeUnsafeRows(r.asInstanceOf[Array[Byte]], buf) + } - res.foreach(buf ++= _.take(n - buf.size)) partsScanned += p.size } - buf.toArray + if (buf.size > n) { + buf.take(n).toArray + } else { + buf.toArray + } } private[this] def isTesting: Boolean = sys.props.contains("spark.testing") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala index b6051b07c8..d293ff66fb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala @@ -465,4 +465,25 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { collect 4 millions 3193 / 3895 0.3 3044.7 0.1X */ } + + ignore("collect limit") { + val N = 1 << 20 + + val benchmark = new Benchmark("collect limit", N) + benchmark.addCase("collect limit 1 million") { iter => + sqlContext.range(N * 4).limit(N).collect() + } + benchmark.addCase("collect limit 2 millions") { iter => + sqlContext.range(N * 4).limit(N * 2).collect() + } + benchmark.run() + + /** + model name : Westmere E56xx/L56xx/X56xx (Nehalem-C) + collect limit: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + collect limit 1 million 833 / 1284 1.3 794.4 1.0X + collect limit 2 millions 3348 / 4005 0.3 3193.3 0.2X + */ + } } |