diff options
author | Yin Huai <yhuai@databricks.com> | 2015-08-07 20:04:17 -0700 |
---|---|---|
committer | Reynold Xin <rxin@databricks.com> | 2015-08-07 20:04:17 -0700 |
commit | c564b27447ed99e55b359b3df1d586d5766b85ea (patch) | |
tree | c2a1e50942d036cc5ceb0f118be6710230a2136f | |
parent | 998f4ff94df1d9db1c9e32c04091017c25cd4e81 (diff) | |
download | spark-c564b27447ed99e55b359b3df1d586d5766b85ea.tar.gz spark-c564b27447ed99e55b359b3df1d586d5766b85ea.tar.bz2 spark-c564b27447ed99e55b359b3df1d586d5766b85ea.zip |
[SPARK-9753] [SQL] TungstenAggregate should also accept InternalRow instead of just UnsafeRow
https://issues.apache.org/jira/browse/SPARK-9753
This PR makes TungstenAggregate to accept `InternalRow` instead of just `UnsafeRow`. Also, it adds an `getAggregationBufferFromUnsafeRow` method to `UnsafeFixedWidthAggregationMap`. It is useful when we already have grouping keys stored in `UnsafeRow`s. Finally, it wraps `InputStream` and `OutputStream` in `UnsafeRowSerializer` with `BufferedInputStream` and `BufferedOutputStream`, respectively.
Author: Yin Huai <yhuai@databricks.com>
Closes #8041 from yhuai/joinedRowForProjection and squashes the following commits:
7753e34 [Yin Huai] Use BufferedInputStream and BufferedOutputStream.
d68b74e [Yin Huai] Use joinedRow instead of UnsafeRowJoiner.
e93c009 [Yin Huai] Add getAggregationBufferFromUnsafeRow for cases that the given groupingKeyRow is already an UnsafeRow.
4 files changed, 39 insertions, 50 deletions
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java index b08a4a13a2..00218f2130 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java @@ -121,6 +121,10 @@ public final class UnsafeFixedWidthAggregationMap { public UnsafeRow getAggregationBuffer(InternalRow groupingKey) { final UnsafeRow unsafeGroupingKeyRow = this.groupingKeyProjection.apply(groupingKey); + return getAggregationBufferFromUnsafeRow(unsafeGroupingKeyRow); + } + + public UnsafeRow getAggregationBufferFromUnsafeRow(UnsafeRow unsafeGroupingKeyRow) { // Probe our map using the serialized key final BytesToBytesMap.Location loc = map.lookup( unsafeGroupingKeyRow.getBaseObject(), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala index 39f8f992a9..6c7e5cacc9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala @@ -58,27 +58,14 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst */ override def serializeStream(out: OutputStream): SerializationStream = new SerializationStream { private[this] var writeBuffer: Array[Byte] = new Array[Byte](4096) - // When `out` is backed by ChainedBufferOutputStream, we will get an - // UnsupportedOperationException when we call dOut.writeInt because it internally calls - // ChainedBufferOutputStream's write(b: Int), which is not supported. - // To workaround this issue, we create an array for sorting the int value. - // To reproduce the problem, use dOut.writeInt(row.getSizeInBytes) and - // run SparkSqlSerializer2SortMergeShuffleSuite. - private[this] var intBuffer: Array[Byte] = new Array[Byte](4) - private[this] val dOut: DataOutputStream = new DataOutputStream(out) + private[this] val dOut: DataOutputStream = + new DataOutputStream(new BufferedOutputStream(out)) override def writeValue[T: ClassTag](value: T): SerializationStream = { val row = value.asInstanceOf[UnsafeRow] - val size = row.getSizeInBytes - // This part is based on DataOutputStream's writeInt. - // It is for dOut.writeInt(row.getSizeInBytes). - intBuffer(0) = ((size >>> 24) & 0xFF).toByte - intBuffer(1) = ((size >>> 16) & 0xFF).toByte - intBuffer(2) = ((size >>> 8) & 0xFF).toByte - intBuffer(3) = ((size >>> 0) & 0xFF).toByte - dOut.write(intBuffer, 0, 4) - - row.writeToStream(out, writeBuffer) + + dOut.writeInt(row.getSizeInBytes) + row.writeToStream(dOut, writeBuffer) this } @@ -105,7 +92,6 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst override def close(): Unit = { writeBuffer = null - intBuffer = null dOut.writeInt(EOF) dOut.close() } @@ -113,7 +99,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst override def deserializeStream(in: InputStream): DeserializationStream = { new DeserializationStream { - private[this] val dIn: DataInputStream = new DataInputStream(in) + private[this] val dIn: DataInputStream = new DataInputStream(new BufferedInputStream(in)) // 1024 is a default buffer size; this buffer will grow to accommodate larger rows private[this] var rowBuffer: Array[Byte] = new Array[Byte](1024) private[this] var row: UnsafeRow = new UnsafeRow() @@ -129,7 +115,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst if (rowBuffer.length < rowSize) { rowBuffer = new Array[Byte](rowSize) } - ByteStreams.readFully(in, rowBuffer, 0, rowSize) + ByteStreams.readFully(dIn, rowBuffer, 0, rowSize) row.pointTo(rowBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, rowSize) rowSize = dIn.readInt() // read the next row's size if (rowSize == EOF) { // We are returning the last row in this stream @@ -163,7 +149,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst if (rowBuffer.length < rowSize) { rowBuffer = new Array[Byte](rowSize) } - ByteStreams.readFully(in, rowBuffer, 0, rowSize) + ByteStreams.readFully(dIn, rowBuffer, 0, rowSize) row.pointTo(rowBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, rowSize) row.asInstanceOf[T] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index c3dcbd2b71..1694794a53 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -39,7 +39,7 @@ case class TungstenAggregate( override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = false + override def canProcessSafeRows: Boolean = true override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) @@ -77,7 +77,7 @@ case class TungstenAggregate( resultExpressions, newMutableProjection, child.output, - iter.asInstanceOf[Iterator[UnsafeRow]], + iter, testFallbackStartsAt) if (!hasInput && groupingExpressions.isEmpty) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index 440bef32f4..32160906c3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -22,6 +22,7 @@ import org.apache.spark.{InternalAccumulator, Logging, SparkEnv, TaskContext} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.{UnsafeKVExternalSorter, UnsafeFixedWidthAggregationMap} import org.apache.spark.sql.types.StructType @@ -46,8 +47,7 @@ import org.apache.spark.sql.types.StructType * processing input rows from inputIter, and generating output * rows. * - Part 3: Methods and fields used by hash-based aggregation. - * - Part 4: The function used to switch this iterator from hash-based - * aggregation to sort-based aggregation. + * - Part 4: Methods and fields used when we switch to sort-based aggregation. * - Part 5: Methods and fields used by sort-based aggregation. * - Part 6: Loads input and process input rows. * - Part 7: Public methods of this iterator. @@ -82,7 +82,7 @@ class TungstenAggregationIterator( resultExpressions: Seq[NamedExpression], newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), originalInputAttributes: Seq[Attribute], - inputIter: Iterator[UnsafeRow], + inputIter: Iterator[InternalRow], testFallbackStartsAt: Option[Int]) extends Iterator[UnsafeRow] with Logging { @@ -174,13 +174,10 @@ class TungstenAggregationIterator( // Creates a function used to process a row based on the given inputAttributes. private def generateProcessRow( - inputAttributes: Seq[Attribute]): (UnsafeRow, UnsafeRow) => Unit = { + inputAttributes: Seq[Attribute]): (UnsafeRow, InternalRow) => Unit = { val aggregationBufferAttributes = allAggregateFunctions.flatMap(_.bufferAttributes) - val aggregationBufferSchema = StructType.fromAttributes(aggregationBufferAttributes) - val inputSchema = StructType.fromAttributes(inputAttributes) - val unsafeRowJoiner = - GenerateUnsafeRowJoiner.create(aggregationBufferSchema, inputSchema) + val joinedRow = new JoinedRow() aggregationMode match { // Partial-only @@ -189,9 +186,9 @@ class TungstenAggregationIterator( val algebraicUpdateProjection = newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)() - (currentBuffer: UnsafeRow, row: UnsafeRow) => { + (currentBuffer: UnsafeRow, row: InternalRow) => { algebraicUpdateProjection.target(currentBuffer) - algebraicUpdateProjection(unsafeRowJoiner.join(currentBuffer, row)) + algebraicUpdateProjection(joinedRow(currentBuffer, row)) } // PartialMerge-only or Final-only @@ -203,10 +200,10 @@ class TungstenAggregationIterator( mergeExpressions, aggregationBufferAttributes ++ inputAttributes)() - (currentBuffer: UnsafeRow, row: UnsafeRow) => { + (currentBuffer: UnsafeRow, row: InternalRow) => { // Process all algebraic aggregate functions. algebraicMergeProjection.target(currentBuffer) - algebraicMergeProjection(unsafeRowJoiner.join(currentBuffer, row)) + algebraicMergeProjection(joinedRow(currentBuffer, row)) } // Final-Complete @@ -233,8 +230,8 @@ class TungstenAggregationIterator( val completeAlgebraicUpdateProjection = newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)() - (currentBuffer: UnsafeRow, row: UnsafeRow) => { - val input = unsafeRowJoiner.join(currentBuffer, row) + (currentBuffer: UnsafeRow, row: InternalRow) => { + val input = joinedRow(currentBuffer, row) // For all aggregate functions with mode Complete, update the given currentBuffer. completeAlgebraicUpdateProjection.target(currentBuffer)(input) @@ -253,14 +250,14 @@ class TungstenAggregationIterator( val completeAlgebraicUpdateProjection = newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)() - (currentBuffer: UnsafeRow, row: UnsafeRow) => { + (currentBuffer: UnsafeRow, row: InternalRow) => { completeAlgebraicUpdateProjection.target(currentBuffer) // For all aggregate functions with mode Complete, update the given currentBuffer. - completeAlgebraicUpdateProjection(unsafeRowJoiner.join(currentBuffer, row)) + completeAlgebraicUpdateProjection(joinedRow(currentBuffer, row)) } // Grouping only. - case (None, None) => (currentBuffer: UnsafeRow, row: UnsafeRow) => {} + case (None, None) => (currentBuffer: UnsafeRow, row: InternalRow) => {} case other => throw new IllegalStateException( @@ -272,15 +269,16 @@ class TungstenAggregationIterator( private def generateResultProjection(): (UnsafeRow, UnsafeRow) => UnsafeRow = { val groupingAttributes = groupingExpressions.map(_.toAttribute) - val groupingKeySchema = StructType.fromAttributes(groupingAttributes) val bufferAttributes = allAggregateFunctions.flatMap(_.bufferAttributes) - val bufferSchema = StructType.fromAttributes(bufferAttributes) - val unsafeRowJoiner = GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema) aggregationMode match { // Partial-only or PartialMerge-only: every output row is basically the values of // the grouping expressions and the corresponding aggregation buffer. case (Some(Partial), None) | (Some(PartialMerge), None) => + val groupingKeySchema = StructType.fromAttributes(groupingAttributes) + val bufferSchema = StructType.fromAttributes(bufferAttributes) + val unsafeRowJoiner = GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema) + (currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => { unsafeRowJoiner.join(currentGroupingKey, currentBuffer) } @@ -288,11 +286,12 @@ class TungstenAggregationIterator( // Final-only, Complete-only and Final-Complete: a output row is generated based on // resultExpressions. case (Some(Final), None) | (Some(Final) | None, Some(Complete)) => + val joinedRow = new JoinedRow() val resultProjection = UnsafeProjection.create(resultExpressions, groupingAttributes ++ bufferAttributes) (currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => { - resultProjection(unsafeRowJoiner.join(currentGroupingKey, currentBuffer)) + resultProjection(joinedRow(currentGroupingKey, currentBuffer)) } // Grouping-only: a output row is generated from values of grouping expressions. @@ -316,7 +315,7 @@ class TungstenAggregationIterator( // A function used to process a input row. Its first argument is the aggregation buffer // and the second argument is the input row. - private[this] var processRow: (UnsafeRow, UnsafeRow) => Unit = + private[this] var processRow: (UnsafeRow, InternalRow) => Unit = generateProcessRow(originalInputAttributes) // A function used to generate output rows based on the grouping keys (first argument) @@ -354,7 +353,7 @@ class TungstenAggregationIterator( while (!sortBased && inputIter.hasNext) { val newInput = inputIter.next() val groupingKey = groupProjection.apply(newInput) - val buffer: UnsafeRow = hashMap.getAggregationBuffer(groupingKey) + val buffer: UnsafeRow = hashMap.getAggregationBufferFromUnsafeRow(groupingKey) if (buffer == null) { // buffer == null means that we could not allocate more memory. // Now, we need to spill the map and switch to sort-based aggregation. @@ -374,7 +373,7 @@ class TungstenAggregationIterator( val newInput = inputIter.next() val groupingKey = groupProjection.apply(newInput) val buffer: UnsafeRow = if (i < fallbackStartsAt) { - hashMap.getAggregationBuffer(groupingKey) + hashMap.getAggregationBufferFromUnsafeRow(groupingKey) } else { null } @@ -397,7 +396,7 @@ class TungstenAggregationIterator( private[this] var mapIteratorHasNext: Boolean = false /////////////////////////////////////////////////////////////////////////// - // Part 3: Methods and fields used by sort-based aggregation. + // Part 4: Methods and fields used when we switch to sort-based aggregation. /////////////////////////////////////////////////////////////////////////// // This sorter is used for sort-based aggregation. It is initialized as soon as @@ -407,7 +406,7 @@ class TungstenAggregationIterator( /** * Switch to sort-based aggregation when the hash-based approach is unable to acquire memory. */ - private def switchToSortBasedAggregation(firstKey: UnsafeRow, firstInput: UnsafeRow): Unit = { + private def switchToSortBasedAggregation(firstKey: UnsafeRow, firstInput: InternalRow): Unit = { logInfo("falling back to sort based aggregation.") // Step 1: Get the ExternalSorter containing sorted entries of the map. externalSorter = hashMap.destructAndCreateExternalSorter() |