Diffstat (limited to 'sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala')
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala | 46 | +++++++++++++++++++++++++++++++++-------------
1 file changed, 33 insertions(+), 13 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 010ed7f500..4091f65aec 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -84,8 +84,8 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
   private[sql] def metrics: Map[String, SQLMetric[_, _]] = Map.empty
 
   /**
-  * Reset all the metrics.
-  */
+   * Reset all the metrics.
+   */
   private[sql] def resetMetrics(): Unit = {
     metrics.valuesIterator.foreach(_.reset())
   }
@@ -249,20 +249,24 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
   /**
    * Decode the byte arrays back to UnsafeRows and put them into buffer.
    */
-  private def decodeUnsafeRows(bytes: Array[Byte], buffer: ArrayBuffer[InternalRow]): Unit = {
+  private def decodeUnsafeRows(bytes: Array[Byte]): Iterator[InternalRow] = {
     val nFields = schema.length
 
     val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
     val bis = new ByteArrayInputStream(bytes)
     val ins = new DataInputStream(codec.compressedInputStream(bis))
-    var sizeOfNextRow = ins.readInt()
-    while (sizeOfNextRow >= 0) {
-      val bs = new Array[Byte](sizeOfNextRow)
-      ins.readFully(bs)
-      val row = new UnsafeRow(nFields)
-      row.pointTo(bs, sizeOfNextRow)
-      buffer += row
-      sizeOfNextRow = ins.readInt()
+
+    new Iterator[InternalRow] {
+      private var sizeOfNextRow = ins.readInt()
+      override def hasNext: Boolean = sizeOfNextRow >= 0
+      override def next(): InternalRow = {
+        val bs = new Array[Byte](sizeOfNextRow)
+        ins.readFully(bs)
+        val row = new UnsafeRow(nFields)
+        row.pointTo(bs, sizeOfNextRow)
+        sizeOfNextRow = ins.readInt()
+        row
+      }
     }
   }
@@ -274,12 +278,21 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
     val results = ArrayBuffer[InternalRow]()
     byteArrayRdd.collect().foreach { bytes =>
-      decodeUnsafeRows(bytes, results)
+      decodeUnsafeRows(bytes).foreach(results.+=)
     }
     results.toArray
   }
 
   /**
+   * Runs this query returning the result as an iterator of InternalRow.
+   *
+   * Note: this will trigger multiple jobs (one for each partition).
+   */
+  def executeToIterator(): Iterator[InternalRow] = {
+    getByteArrayRdd().toLocalIterator.flatMap(decodeUnsafeRows)
+  }
+
+  /**
    * Runs this query returning the result as an array, using external Row format.
    */
   def executeCollectPublic(): Array[Row] = {
@@ -325,7 +338,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
         (it: Iterator[Array[Byte]]) => if (it.hasNext) it.next() else Array.empty, p)
 
       res.foreach { r =>
-        decodeUnsafeRows(r.asInstanceOf[Array[Byte]], buf)
+        decodeUnsafeRows(r.asInstanceOf[Array[Byte]]).foreach(buf.+=)
       }
 
       partsScanned += p.size
@@ -379,6 +392,13 @@ private[sql] trait LeafNode extends SparkPlan {
   override def producedAttributes: AttributeSet = outputSet
 }
 
+object UnaryNode {
+  def unapply(a: Any): Option[(SparkPlan, SparkPlan)] = a match {
+    case s: SparkPlan if s.children.size == 1 => Some((s, s.children.head))
+    case _ => None
+  }
+}
+
 private[sql] trait UnaryNode extends SparkPlan {
   def child: SparkPlan
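
The heart of the change is that decodeUnsafeRows no longer fills a caller-supplied ArrayBuffer but returns a lazy Iterator, so a row is only decoded when the consumer asks for it. The wire format itself is unchanged: each row is written as a 4-byte length followed by that many bytes, with a negative length marking end-of-stream. Below is a minimal, self-contained sketch of that framing contract; the names are hypothetical, it uses plain byte arrays instead of UnsafeRow, and it omits the compression codec:

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

    object LengthPrefixedCodec {
      // Write each record as <4-byte length><bytes>; -1 terminates the stream.
      def encode(records: Seq[Array[Byte]]): Array[Byte] = {
        val bos = new ByteArrayOutputStream()
        val out = new DataOutputStream(bos)
        records.foreach { r =>
          out.writeInt(r.length)
          out.write(r)
        }
        out.writeInt(-1) // sentinel, plays the role of the `sizeOfNextRow >= 0` check
        out.flush()
        bos.toByteArray
      }

      // Decode lazily, mirroring the new decodeUnsafeRows: nothing is read
      // from the stream until next() is called.
      def decode(bytes: Array[Byte]): Iterator[Array[Byte]] = {
        val ins = new DataInputStream(new ByteArrayInputStream(bytes))
        new Iterator[Array[Byte]] {
          private var sizeOfNext = ins.readInt()
          override def hasNext: Boolean = sizeOfNext >= 0
          override def next(): Array[Byte] = {
            val bs = new Array[Byte](sizeOfNext)
            ins.readFully(bs)
            sizeOfNext = ins.readInt() // pre-read the next header so hasNext stays cheap
            bs
          }
        }
      }
    }

    // Round trip: decode() returns immediately; records stream on demand.
    val encoded = LengthPrefixedCodec.encode(Seq("a".getBytes, "bc".getBytes))
    LengthPrefixedCodec.decode(encoded).foreach(r => println(new String(r)))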
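
The new executeToIterator builds on this: RDD.toLocalIterator fetches one partition per Spark job, and flatMap(decodeUnsafeRows) decodes each fetched block lazily, so at most one partition's rows need to be resident on the driver at a time, whereas executeCollect materializes everything up front. A hedged usage sketch, assuming some `plan: SparkPlan` is in scope:

    // Streams results instead of collecting them; note the scaladoc caveat
    // that this triggers one job per partition.
    val rows: Iterator[InternalRow] = plan.executeToIterator()
    rows.take(5).foreach(println) // typically only the first partition is fetched and decoded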
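
The added UnaryNode extractor object lets planner or rule code match any physical operator that has exactly one child without naming a concrete operator class. An illustrative, hypothetical match, again assuming `plan: SparkPlan`:

    plan match {
      case UnaryNode(node, child) =>
        // node is the operator itself, child its single child
        println(s"${node.nodeName} -> ${child.nodeName}")
      case _ => () // leaf or multi-child operator
    }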