Diffstat (limited to 'sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala | 46
1 file changed, 33 insertions(+), 13 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 010ed7f500..4091f65aec 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -84,8 +84,8 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializable {
   private[sql] def metrics: Map[String, SQLMetric[_, _]] = Map.empty
 
   /**
-    * Reset all the metrics.
-    */
+   * Reset all the metrics.
+   */
   private[sql] def resetMetrics(): Unit = {
     metrics.valuesIterator.foreach(_.reset())
   }
@@ -249,20 +249,24 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializable {
   /**
    * Decode the byte arrays back to UnsafeRows and put them into buffer.
    */
-  private def decodeUnsafeRows(bytes: Array[Byte], buffer: ArrayBuffer[InternalRow]): Unit = {
+  private def decodeUnsafeRows(bytes: Array[Byte]): Iterator[InternalRow] = {
     val nFields = schema.length
 
     val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
     val bis = new ByteArrayInputStream(bytes)
     val ins = new DataInputStream(codec.compressedInputStream(bis))
-    var sizeOfNextRow = ins.readInt()
-    while (sizeOfNextRow >= 0) {
-      val bs = new Array[Byte](sizeOfNextRow)
-      ins.readFully(bs)
-      val row = new UnsafeRow(nFields)
-      row.pointTo(bs, sizeOfNextRow)
-      buffer += row
-      sizeOfNextRow = ins.readInt()
+
+    new Iterator[InternalRow] {
+      private var sizeOfNextRow = ins.readInt()
+      override def hasNext: Boolean = sizeOfNextRow >= 0
+      override def next(): InternalRow = {
+        val bs = new Array[Byte](sizeOfNextRow)
+        ins.readFully(bs)
+        val row = new UnsafeRow(nFields)
+        row.pointTo(bs, sizeOfNextRow)
+        sizeOfNextRow = ins.readInt()
+        row
+      }
     }
   }
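For context, the rewritten decoder consumes a simple length-prefixed framing: each record is a 4-byte length written with DataOutputStream.writeInt, followed by that many row bytes, with a negative length marking end-of-stream (the compression wrapper is elided here). A minimal, self-contained Scala sketch of that framing, with hypothetical writeFrames/readFrames helpers standing in for the encode and decode sides:

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

    // Encode: [4-byte length][payload] per record, then a -1 terminator.
    def writeFrames(rows: Seq[Array[Byte]]): Array[Byte] = {
      val bos = new ByteArrayOutputStream()
      val out = new DataOutputStream(bos)
      rows.foreach { bs => out.writeInt(bs.length); out.write(bs) }
      out.writeInt(-1)  // hasNext above stops once sizeOfNextRow < 0
      out.flush()
      bos.toByteArray
    }

    // Decode: the same lazy-iterator shape as decodeUnsafeRows above.
    def readFrames(bytes: Array[Byte]): Iterator[Array[Byte]] = {
      val in = new DataInputStream(new ByteArrayInputStream(bytes))
      new Iterator[Array[Byte]] {
        private var size = in.readInt()
        override def hasNext: Boolean = size >= 0
        override def next(): Array[Byte] = {
          val bs = new Array[Byte](size)
          in.readFully(bs)
          size = in.readInt()
          bs
        }
      }
    }

Unlike the old buffer-filling loop, nothing is materialized until the caller pulls from the iterator, which is what lets executeToIterator below stream results.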
@@ -274,12 +278,21 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializable {
     val results = ArrayBuffer[InternalRow]()
     byteArrayRdd.collect().foreach { bytes =>
-      decodeUnsafeRows(bytes, results)
+      decodeUnsafeRows(bytes).foreach(results.+=)
     }
     results.toArray
   }
 
   /**
+   * Runs this query returning the result as an iterator of InternalRow.
+   *
+   * Note: this will trigger multiple jobs (one for each partition).
+   */
+  def executeToIterator(): Iterator[InternalRow] = {
+    getByteArrayRdd().toLocalIterator.flatMap(decodeUnsafeRows)
+  }
+
+  /**
    * Runs this query returning the result as an array, using external Row format.
    */
   def executeCollectPublic(): Array[Row] = {
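The note in the new doc comment is the key trade-off: RDD.toLocalIterator fetches one partition per job, so executeToIterator streams rows to the driver while holding at most one partition's bytes at a time, at the cost of one job per partition consumed. A hedged usage sketch, assuming `plan` is an already-prepared physical plan:

    // Only the partitions actually consumed trigger jobs; pulling a few
    // rows from a many-partition plan avoids a full executeCollect().
    val it: Iterator[InternalRow] = plan.executeToIterator()
    it.take(10).foreach(row => println(row))

Because decodeUnsafeRows is itself lazy, the flatMap in executeToIterator never decodes a partition's rows until the caller requests them.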
@@ -325,7 +338,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializable {
         (it: Iterator[Array[Byte]]) => if (it.hasNext) it.next() else Array.empty, p)
 
       res.foreach { r =>
-        decodeUnsafeRows(r.asInstanceOf[Array[Byte]], buf)
+        decodeUnsafeRows(r.asInstanceOf[Array[Byte]]).foreach(buf.+=)
       }
 
       partsScanned += p.size
@@ -379,6 +392,13 @@ private[sql] trait LeafNode extends SparkPlan {
   override def producedAttributes: AttributeSet = outputSet
 }
 
+object UnaryNode {
+  def unapply(a: Any): Option[(SparkPlan, SparkPlan)] = a match {
+    case s: SparkPlan if s.children.size == 1 => Some((s, s.children.head))
+    case _ => None
+  }
+}
+
 private[sql] trait UnaryNode extends SparkPlan {
   def child: SparkPlan
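The new UnaryNode companion object gives planner rules an extractor that matches any single-child physical plan, regardless of its concrete class. A brief sketch of the shape it enables, where `somePlan` is an illustrative placeholder:

    // Matches any SparkPlan with exactly one child.
    somePlan match {
      case UnaryNode(node, child) => println(s"${node.nodeName} wraps ${child.nodeName}")
      case _ => // leaf or multi-child plan
    }

Because unapply accepts Any and tests children.size, the extractor also nests, e.g. UnaryNode(_, UnaryNode(_, grandchild)) peels two single-child operators at once.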