author     Yin Huai <yhuai@databricks.com>    2015-08-07 20:04:17 -0700
committer  Reynold Xin <rxin@databricks.com>  2015-08-07 20:04:17 -0700
commit     c564b27447ed99e55b359b3df1d586d5766b85ea (patch)
tree       c2a1e50942d036cc5ceb0f118be6710230a2136f /sql
parent     998f4ff94df1d9db1c9e32c04091017c25cd4e81 (diff)
[SPARK-9753] [SQL] TungstenAggregate should also accept InternalRow instead of just UnsafeRow
https://issues.apache.org/jira/browse/SPARK-9753

This PR makes TungstenAggregate accept `InternalRow` instead of just `UnsafeRow`. It also adds a `getAggregationBufferFromUnsafeRow` method to `UnsafeFixedWidthAggregationMap`, which is useful when the grouping keys are already stored in `UnsafeRow`s. Finally, it wraps the `InputStream` and `OutputStream` in `UnsafeRowSerializer` with `BufferedInputStream` and `BufferedOutputStream`, respectively.

Author: Yin Huai <yhuai@databricks.com>

Closes #8041 from yhuai/joinedRowForProjection and squashes the following commits:

7753e34 [Yin Huai] Use BufferedInputStream and BufferedOutputStream.
d68b74e [Yin Huai] Use joinedRow instead of UnsafeRowJoiner.
e93c009 [Yin Huai] Add getAggregationBufferFromUnsafeRow for cases where the given groupingKeyRow is already an UnsafeRow.
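The buffering change is mechanical but easy to get wrong, so the sketch below shows the intended wrapping in isolation. It is a minimal, self-contained illustration with made-up helper names, not Spark's actual fields:

    import java.io.{BufferedInputStream, BufferedOutputStream}
    import java.io.{DataInputStream, DataOutputStream, InputStream, OutputStream}

    // Small writes (such as the 4-byte row-size prefix) now land in an
    // in-memory buffer and reach the underlying stream in larger chunks,
    // instead of hitting the raw stream once per call.
    def bufferedWriter(out: OutputStream): DataOutputStream =
      new DataOutputStream(new BufferedOutputStream(out))

    def bufferedReader(in: InputStream): DataInputStream =
      new DataInputStream(new BufferedInputStream(in))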
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java           4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala                   30
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala            4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala  51
4 files changed, 39 insertions(+), 50 deletions(-)
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
index b08a4a13a2..00218f2130 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
@@ -121,6 +121,10 @@ public final class UnsafeFixedWidthAggregationMap {
public UnsafeRow getAggregationBuffer(InternalRow groupingKey) {
final UnsafeRow unsafeGroupingKeyRow = this.groupingKeyProjection.apply(groupingKey);
+ return getAggregationBufferFromUnsafeRow(unsafeGroupingKeyRow);
+ }
+
+ public UnsafeRow getAggregationBufferFromUnsafeRow(UnsafeRow unsafeGroupingKeyRow) {
// Probe our map using the serialized key
final BytesToBytesMap.Location loc = map.lookup(
unsafeGroupingKeyRow.getBaseObject(),
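The hunk above splits the lookup into a convert step and a probe step, so a caller that already holds an `UnsafeRow` key can skip the projection. A self-contained analogue of that refactor, with hypothetical types rather than Spark's API:

    // The generic entry point converts the key, then delegates to an
    // overload that assumes the key is already in its binary form.
    final case class BinaryKey(bytes: Vector[Byte])

    class AggregationMapSketch {
      private val buffers = scala.collection.mutable.Map.empty[BinaryKey, Array[Long]]

      // Slow path: project an arbitrary key into the binary format first.
      def getBuffer(key: Seq[Int]): Array[Long] =
        getBufferFromBinaryKey(BinaryKey(key.map(_.toByte).toVector))

      // Fast path: the caller already holds a binary key, so no conversion.
      def getBufferFromBinaryKey(key: BinaryKey): Array[Long] =
        buffers.getOrElseUpdate(key, new Array[Long](1))
    }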
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala
index 39f8f992a9..6c7e5cacc9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala
@@ -58,27 +58,14 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst
*/
override def serializeStream(out: OutputStream): SerializationStream = new SerializationStream {
private[this] var writeBuffer: Array[Byte] = new Array[Byte](4096)
- // When `out` is backed by ChainedBufferOutputStream, we will get an
- // UnsupportedOperationException when we call dOut.writeInt because it internally calls
- // ChainedBufferOutputStream's write(b: Int), which is not supported.
- // To workaround this issue, we create an array for sorting the int value.
- // To reproduce the problem, use dOut.writeInt(row.getSizeInBytes) and
- // run SparkSqlSerializer2SortMergeShuffleSuite.
- private[this] var intBuffer: Array[Byte] = new Array[Byte](4)
- private[this] val dOut: DataOutputStream = new DataOutputStream(out)
+ private[this] val dOut: DataOutputStream =
+ new DataOutputStream(new BufferedOutputStream(out))
override def writeValue[T: ClassTag](value: T): SerializationStream = {
val row = value.asInstanceOf[UnsafeRow]
- val size = row.getSizeInBytes
- // This part is based on DataOutputStream's writeInt.
- // It is for dOut.writeInt(row.getSizeInBytes).
- intBuffer(0) = ((size >>> 24) & 0xFF).toByte
- intBuffer(1) = ((size >>> 16) & 0xFF).toByte
- intBuffer(2) = ((size >>> 8) & 0xFF).toByte
- intBuffer(3) = ((size >>> 0) & 0xFF).toByte
- dOut.write(intBuffer, 0, 4)
-
- row.writeToStream(out, writeBuffer)
+
+ dOut.writeInt(row.getSizeInBytes)
+ row.writeToStream(dOut, writeBuffer)
this
}
@@ -105,7 +92,6 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst
override def close(): Unit = {
writeBuffer = null
- intBuffer = null
dOut.writeInt(EOF)
dOut.close()
}
@@ -113,7 +99,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst
override def deserializeStream(in: InputStream): DeserializationStream = {
new DeserializationStream {
- private[this] val dIn: DataInputStream = new DataInputStream(in)
+ private[this] val dIn: DataInputStream = new DataInputStream(new BufferedInputStream(in))
// 1024 is a default buffer size; this buffer will grow to accommodate larger rows
private[this] var rowBuffer: Array[Byte] = new Array[Byte](1024)
private[this] var row: UnsafeRow = new UnsafeRow()
@@ -129,7 +115,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst
if (rowBuffer.length < rowSize) {
rowBuffer = new Array[Byte](rowSize)
}
- ByteStreams.readFully(in, rowBuffer, 0, rowSize)
+ ByteStreams.readFully(dIn, rowBuffer, 0, rowSize)
row.pointTo(rowBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, rowSize)
rowSize = dIn.readInt() // read the next row's size
if (rowSize == EOF) { // We are returning the last row in this stream
@@ -163,7 +149,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst
if (rowBuffer.length < rowSize) {
rowBuffer = new Array[Byte](rowSize)
}
- ByteStreams.readFully(in, rowBuffer, 0, rowSize)
+ ByteStreams.readFully(dIn, rowBuffer, 0, rowSize)
row.pointTo(rowBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, rowSize)
row.asInstanceOf[T]
}
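After this change the serializer's wire format is a stream of length-prefixed records terminated by a sentinel size, with every read and write going through one buffered stream. A runnable sketch of that format, assuming a negative EOF sentinel as in the real serializer:

    import java.io._

    object RowFramingSketch {
      val EOF = -1 // assumed sentinel; the real constant lives in UnsafeRowSerializer

      def writeAll(out: OutputStream, rows: Seq[Array[Byte]]): Unit = {
        // One buffered stream carries both the 4-byte size prefix and the payload.
        val dOut = new DataOutputStream(new BufferedOutputStream(out))
        rows.foreach { r => dOut.writeInt(r.length); dOut.write(r) }
        dOut.writeInt(EOF)
        dOut.flush()
      }

      def readAll(in: InputStream): Vector[Array[Byte]] = {
        // Reads must also go through the buffered stream: pulling payload
        // bytes from the raw `in` would lose data already sitting in the
        // buffer, which is what the ByteStreams.readFully(dIn, ...) changes
        // in the diff above guard against.
        val dIn = new DataInputStream(new BufferedInputStream(in))
        Iterator.continually(dIn.readInt())
          .takeWhile(_ != EOF)
          .map { size =>
            val buf = new Array[Byte](size)
            dIn.readFully(buf)
            buf
          }
          .toVector
      }
    }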
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index c3dcbd2b71..1694794a53 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -39,7 +39,7 @@ case class TungstenAggregate(
override def canProcessUnsafeRows: Boolean = true
- override def canProcessSafeRows: Boolean = false
+ override def canProcessSafeRows: Boolean = true
override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
@@ -77,7 +77,7 @@ case class TungstenAggregate(
resultExpressions,
newMutableProjection,
child.output,
- iter.asInstanceOf[Iterator[UnsafeRow]],
+ iter,
testFallbackStartsAt)
if (!hasInput && groupingExpressions.isEmpty) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
index 440bef32f4..32160906c3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
@@ -22,6 +22,7 @@ import org.apache.spark.{InternalAccumulator, Logging, SparkEnv, TaskContext}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
+import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.{UnsafeKVExternalSorter, UnsafeFixedWidthAggregationMap}
import org.apache.spark.sql.types.StructType
@@ -46,8 +47,7 @@ import org.apache.spark.sql.types.StructType
* processing input rows from inputIter, and generating output
* rows.
* - Part 3: Methods and fields used by hash-based aggregation.
- * - Part 4: The function used to switch this iterator from hash-based
- * aggregation to sort-based aggregation.
+ * - Part 4: Methods and fields used when we switch to sort-based aggregation.
* - Part 5: Methods and fields used by sort-based aggregation.
* - Part 6: Loads input and process input rows.
* - Part 7: Public methods of this iterator.
@@ -82,7 +82,7 @@ class TungstenAggregationIterator(
resultExpressions: Seq[NamedExpression],
newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
originalInputAttributes: Seq[Attribute],
- inputIter: Iterator[UnsafeRow],
+ inputIter: Iterator[InternalRow],
testFallbackStartsAt: Option[Int])
extends Iterator[UnsafeRow] with Logging {
@@ -174,13 +174,10 @@ class TungstenAggregationIterator(
// Creates a function used to process a row based on the given inputAttributes.
private def generateProcessRow(
- inputAttributes: Seq[Attribute]): (UnsafeRow, UnsafeRow) => Unit = {
+ inputAttributes: Seq[Attribute]): (UnsafeRow, InternalRow) => Unit = {
val aggregationBufferAttributes = allAggregateFunctions.flatMap(_.bufferAttributes)
- val aggregationBufferSchema = StructType.fromAttributes(aggregationBufferAttributes)
- val inputSchema = StructType.fromAttributes(inputAttributes)
- val unsafeRowJoiner =
- GenerateUnsafeRowJoiner.create(aggregationBufferSchema, inputSchema)
+ val joinedRow = new JoinedRow()
aggregationMode match {
// Partial-only
@@ -189,9 +186,9 @@ class TungstenAggregationIterator(
val algebraicUpdateProjection =
newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)()
- (currentBuffer: UnsafeRow, row: UnsafeRow) => {
+ (currentBuffer: UnsafeRow, row: InternalRow) => {
algebraicUpdateProjection.target(currentBuffer)
- algebraicUpdateProjection(unsafeRowJoiner.join(currentBuffer, row))
+ algebraicUpdateProjection(joinedRow(currentBuffer, row))
}
// PartialMerge-only or Final-only
@@ -203,10 +200,10 @@ class TungstenAggregationIterator(
mergeExpressions,
aggregationBufferAttributes ++ inputAttributes)()
- (currentBuffer: UnsafeRow, row: UnsafeRow) => {
+ (currentBuffer: UnsafeRow, row: InternalRow) => {
// Process all algebraic aggregate functions.
algebraicMergeProjection.target(currentBuffer)
- algebraicMergeProjection(unsafeRowJoiner.join(currentBuffer, row))
+ algebraicMergeProjection(joinedRow(currentBuffer, row))
}
// Final-Complete
@@ -233,8 +230,8 @@ class TungstenAggregationIterator(
val completeAlgebraicUpdateProjection =
newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)()
- (currentBuffer: UnsafeRow, row: UnsafeRow) => {
- val input = unsafeRowJoiner.join(currentBuffer, row)
+ (currentBuffer: UnsafeRow, row: InternalRow) => {
+ val input = joinedRow(currentBuffer, row)
// For all aggregate functions with mode Complete, update the given currentBuffer.
completeAlgebraicUpdateProjection.target(currentBuffer)(input)
@@ -253,14 +250,14 @@ class TungstenAggregationIterator(
val completeAlgebraicUpdateProjection =
newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)()
- (currentBuffer: UnsafeRow, row: UnsafeRow) => {
+ (currentBuffer: UnsafeRow, row: InternalRow) => {
completeAlgebraicUpdateProjection.target(currentBuffer)
// For all aggregate functions with mode Complete, update the given currentBuffer.
- completeAlgebraicUpdateProjection(unsafeRowJoiner.join(currentBuffer, row))
+ completeAlgebraicUpdateProjection(joinedRow(currentBuffer, row))
}
// Grouping only.
- case (None, None) => (currentBuffer: UnsafeRow, row: UnsafeRow) => {}
+ case (None, None) => (currentBuffer: UnsafeRow, row: InternalRow) => {}
case other =>
throw new IllegalStateException(
@@ -272,15 +269,16 @@ class TungstenAggregationIterator(
private def generateResultProjection(): (UnsafeRow, UnsafeRow) => UnsafeRow = {
val groupingAttributes = groupingExpressions.map(_.toAttribute)
- val groupingKeySchema = StructType.fromAttributes(groupingAttributes)
val bufferAttributes = allAggregateFunctions.flatMap(_.bufferAttributes)
- val bufferSchema = StructType.fromAttributes(bufferAttributes)
- val unsafeRowJoiner = GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema)
aggregationMode match {
// Partial-only or PartialMerge-only: every output row is basically the values of
// the grouping expressions and the corresponding aggregation buffer.
case (Some(Partial), None) | (Some(PartialMerge), None) =>
+ val groupingKeySchema = StructType.fromAttributes(groupingAttributes)
+ val bufferSchema = StructType.fromAttributes(bufferAttributes)
+ val unsafeRowJoiner = GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema)
+
(currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => {
unsafeRowJoiner.join(currentGroupingKey, currentBuffer)
}
@@ -288,11 +286,12 @@ class TungstenAggregationIterator(
// Final-only, Complete-only and Final-Complete: a output row is generated based on
// resultExpressions.
case (Some(Final), None) | (Some(Final) | None, Some(Complete)) =>
+ val joinedRow = new JoinedRow()
val resultProjection =
UnsafeProjection.create(resultExpressions, groupingAttributes ++ bufferAttributes)
(currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => {
- resultProjection(unsafeRowJoiner.join(currentGroupingKey, currentBuffer))
+ resultProjection(joinedRow(currentGroupingKey, currentBuffer))
}
// Grouping-only: a output row is generated from values of grouping expressions.
@@ -316,7 +315,7 @@ class TungstenAggregationIterator(
// A function used to process a input row. Its first argument is the aggregation buffer
// and the second argument is the input row.
- private[this] var processRow: (UnsafeRow, UnsafeRow) => Unit =
+ private[this] var processRow: (UnsafeRow, InternalRow) => Unit =
generateProcessRow(originalInputAttributes)
// A function used to generate output rows based on the grouping keys (first argument)
@@ -354,7 +353,7 @@ class TungstenAggregationIterator(
while (!sortBased && inputIter.hasNext) {
val newInput = inputIter.next()
val groupingKey = groupProjection.apply(newInput)
- val buffer: UnsafeRow = hashMap.getAggregationBuffer(groupingKey)
+ val buffer: UnsafeRow = hashMap.getAggregationBufferFromUnsafeRow(groupingKey)
if (buffer == null) {
// buffer == null means that we could not allocate more memory.
// Now, we need to spill the map and switch to sort-based aggregation.
@@ -374,7 +373,7 @@ class TungstenAggregationIterator(
val newInput = inputIter.next()
val groupingKey = groupProjection.apply(newInput)
val buffer: UnsafeRow = if (i < fallbackStartsAt) {
- hashMap.getAggregationBuffer(groupingKey)
+ hashMap.getAggregationBufferFromUnsafeRow(groupingKey)
} else {
null
}
@@ -397,7 +396,7 @@ class TungstenAggregationIterator(
private[this] var mapIteratorHasNext: Boolean = false
///////////////////////////////////////////////////////////////////////////
- // Part 3: Methods and fields used by sort-based aggregation.
+ // Part 4: Methods and fields used when we switch to sort-based aggregation.
///////////////////////////////////////////////////////////////////////////
// This sorter is used for sort-based aggregation. It is initialized as soon as
@@ -407,7 +406,7 @@ class TungstenAggregationIterator(
/**
* Switch to sort-based aggregation when the hash-based approach is unable to acquire memory.
*/
- private def switchToSortBasedAggregation(firstKey: UnsafeRow, firstInput: UnsafeRow): Unit = {
+ private def switchToSortBasedAggregation(firstKey: UnsafeRow, firstInput: InternalRow): Unit = {
logInfo("falling back to sort based aggregation.")
// Step 1: Get the ExternalSorter containing sorted entries of the map.
externalSorter = hashMap.destructAndCreateExternalSorter()
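On the per-row path, the patch replaces the code-generated UnsafeRowJoiner with a JoinedRow: a reusable view that presents the aggregation buffer and the input row as one concatenated row without copying, and that accepts any InternalRow input. A simplified, self-contained sketch of that idea (stand-in types, not Spark's InternalRow):

    // A mutable two-row view: apply() repoints it, get() reads across the seam.
    final class JoinedRowSketch {
      private var left: IndexedSeq[Any] = Vector.empty
      private var right: IndexedSeq[Any] = Vector.empty

      // Repoint the view at a new pair of rows; no field data is copied.
      def apply(l: IndexedSeq[Any], r: IndexedSeq[Any]): JoinedRowSketch = {
        left = l; right = r; this
      }

      def numFields: Int = left.length + right.length

      // Field i comes from the left row, or from the right row past its end.
      def get(i: Int): Any = if (i < left.length) left(i) else right(i - left.length)
    }

Note the trade-off visible in generateResultProjection above: the Partial-only branch keeps GenerateUnsafeRowJoiner because its output must itself be a contiguous UnsafeRow, while the Final/Complete branch can feed a JoinedRow into an UnsafeProjection that produces the final UnsafeRow anyway.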