Diffstat (limited to 'sql')
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala | 25
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala | 35
-rw-r--r-- sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java | 10
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala | 1
-rw-r--r-- sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala | 2
5 files changed, 61 insertions(+), 12 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index a39a2113e5..8dfe8ff702 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql
import java.io.CharArrayWriter
import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
import scala.language.implicitConversions
import scala.reflect.runtime.universe.TypeTag
import scala.util.control.NonFatal
@@ -2057,6 +2058,24 @@ class Dataset[T] private[sql](
}
/**
+ * Returns an iterator that contains all rows in this [[Dataset]].
+ *
+ * The iterator will consume as much memory as the largest partition in this [[Dataset]].
+ *
+ * Note: this results in multiple Spark jobs, and if the input Dataset is the result
+ * of a wide transformation (e.g. a join with different partitioners), the input
+ * Dataset should be cached first to avoid recomputation.
+ *
+ * @group action
+ * @since 2.0.0
+ */
+ def toLocalIterator(): java.util.Iterator[T] = withCallback("toLocalIterator", toDF()) { _ =>
+ withNewExecutionId {
+ queryExecution.executedPlan.executeToIterator().map(boundTEncoder.fromRow).asJava
+ }
+ }
+
+ /**
* Returns the number of rows in the [[Dataset]].
* @group action
* @since 1.6.0
@@ -2300,6 +2319,12 @@ class Dataset[T] private[sql](
}
}
+ protected[sql] def toPythonIterator(): Int = {
+ withNewExecutionId {
+ PythonRDD.toLocalIteratorAndServe(javaToPython.rdd)
+ }
+ }
+
////////////////////////////////////////////////////////////////////////////
// Private Helpers
////////////////////////////////////////////////////////////////////////////
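For reference, a minimal usage sketch of the new API from the Scala side. The SQLContext parameter and the sample data are assumptions for illustration; only the toLocalIterator() call itself comes from the diff above.

import scala.collection.JavaConverters._
import org.apache.spark.sql.{Dataset, SQLContext}

// Sketch: iterate over a Dataset lazily on the driver. toLocalIterator()
// returns a java.util.Iterator[T]; asScala adapts it for Scala callers.
// At most one partition is materialized in driver memory at a time, at the
// cost of one Spark job per partition.
def printAll(sqlContext: SQLContext): Unit = {
  import sqlContext.implicits._
  val ds: Dataset[String] = Seq("hello", "world").toDS()
  ds.toLocalIterator().asScala.foreach(println)
}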
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index ff19d1be1c..4091f65aec 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -249,20 +249,24 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
/**
* Decode the byte arrays back to UnsafeRows and put them into buffer.
*/
- private def decodeUnsafeRows(bytes: Array[Byte], buffer: ArrayBuffer[InternalRow]): Unit = {
+ private def decodeUnsafeRows(bytes: Array[Byte]): Iterator[InternalRow] = {
val nFields = schema.length
val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
val bis = new ByteArrayInputStream(bytes)
val ins = new DataInputStream(codec.compressedInputStream(bis))
- var sizeOfNextRow = ins.readInt()
- while (sizeOfNextRow >= 0) {
- val bs = new Array[Byte](sizeOfNextRow)
- ins.readFully(bs)
- val row = new UnsafeRow(nFields)
- row.pointTo(bs, sizeOfNextRow)
- buffer += row
- sizeOfNextRow = ins.readInt()
+
+ new Iterator[InternalRow] {
+ private var sizeOfNextRow = ins.readInt()
+ override def hasNext: Boolean = sizeOfNextRow >= 0
+ override def next(): InternalRow = {
+ val bs = new Array[Byte](sizeOfNextRow)
+ ins.readFully(bs)
+ val row = new UnsafeRow(nFields)
+ row.pointTo(bs, sizeOfNextRow)
+ sizeOfNextRow = ins.readInt()
+ row
+ }
}
}
@@ -274,12 +278,21 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
val results = ArrayBuffer[InternalRow]()
byteArrayRdd.collect().foreach { bytes =>
- decodeUnsafeRows(bytes, results)
+ decodeUnsafeRows(bytes).foreach(results.+=)
}
results.toArray
}
/**
+ * Runs this query returning the result as an iterator of InternalRow.
+ *
+ * Note: this will trigger multiple jobs (one for each partition).
+ */
+ def executeToIterator(): Iterator[InternalRow] = {
+ getByteArrayRdd().toLocalIterator.flatMap(decodeUnsafeRows)
+ }
+
+ /**
* Runs this query returning the result as an array, using external Row format.
*/
def executeCollectPublic(): Array[Row] = {
@@ -325,7 +338,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
(it: Iterator[Array[Byte]]) => if (it.hasNext) it.next() else Array.empty, p)
res.foreach { r =>
- decodeUnsafeRows(r.asInstanceOf[Array[Byte]], buf)
+ decodeUnsafeRows(r.asInstanceOf[Array[Byte]]).foreach(buf.+=)
}
partsScanned += p.size
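The iterator returned by decodeUnsafeRows reads a length-prefixed, compressed byte stream terminated by a negative length, and executeToIterator simply flatMaps that decoder over RDD.toLocalIterator. Below is a self-contained sketch of the same framing, with plain byte arrays standing in for UnsafeRows and the compression codec omitted; all names here are illustrative, not part of the diff.

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

// Encode each record as (length, bytes) and terminate with -1, mirroring
// the sizeOfNextRow >= 0 check used above as the end-of-stream test.
def encode(records: Seq[Array[Byte]]): Array[Byte] = {
  val bos = new ByteArrayOutputStream()
  val out = new DataOutputStream(bos)
  records.foreach { r => out.writeInt(r.length); out.write(r) }
  out.writeInt(-1)
  out.flush()
  bos.toByteArray
}

// Decode lazily: each next() reads exactly one record from the stream,
// so only one record is held in memory beyond the underlying byte array.
def decode(bytes: Array[Byte]): Iterator[Array[Byte]] = {
  val in = new DataInputStream(new ByteArrayInputStream(bytes))
  new Iterator[Array[Byte]] {
    private var sizeOfNext = in.readInt()
    override def hasNext: Boolean = sizeOfNext >= 0
    override def next(): Array[Byte] = {
      val buf = new Array[Byte](sizeOfNext)
      in.readFully(buf)
      sizeOfNext = in.readInt()
      buf
    }
  }
}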
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index 873f681bdf..f26c57b301 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -86,6 +86,16 @@ public class JavaDatasetSuite implements Serializable {
}
@Test
+ public void testToLocalIterator() {
+ List<String> data = Arrays.asList("hello", "world");
+ Dataset<String> ds = context.createDataset(data, Encoders.STRING());
+ Iterator<String> iter = ds.toLocalIterator();
+ Assert.assertEquals("hello", iter.next());
+ Assert.assertEquals("world", iter.next());
+ Assert.assertFalse(iter.hasNext());
+ }
+
+ @Test
public void testCommonOperation() {
List<String> data = Arrays.asList("hello", "world");
Dataset<String> ds = context.createDataset(data, Encoders.STRING());
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 553bc436a6..2aa90568c3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -71,6 +71,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
assert(ds.first() == item)
assert(ds.take(1).head == item)
assert(ds.takeAsList(1).get(0) == item)
+ assert(ds.toLocalIterator().next() === item)
}
test("coalesce, repartition") {
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala
index a955314ba3..673a293ce2 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala
@@ -222,7 +222,7 @@ private[hive] class SparkExecuteStatementOperation(
val useIncrementalCollect =
hiveContext.getConf("spark.sql.thriftServer.incrementalCollect", "false").toBoolean
if (useIncrementalCollect) {
- result.rdd.toLocalIterator
+ result.toLocalIterator.asScala
} else {
result.collect().iterator
}