 core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala                                                  |  4
 python/pyspark/rdd.py                                                                                            |  8
 python/pyspark/sql/dataframe.py                                                                                  | 14
 sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala                                                       | 25
 sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala                                           | 35
 sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java                                           | 10
 sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala                                                  |  1
 sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala |  2
 8 files changed, 83 insertions(+), 16 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 6faa03c12b..4bca16a234 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -453,6 +453,10 @@ private[spark] object PythonRDD extends Logging {
serveIterator(rdd.collect().iterator, s"serve RDD ${rdd.id}")
}
+ def toLocalIteratorAndServe[T](rdd: RDD[T]): Int = {
+ serveIterator(rdd.toLocalIterator, s"serve toLocalIterator")
+ }
+
def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int):
JavaRDD[Array[Byte]] = {
val file = new DataInputStream(new FileInputStream(filename))
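
For context on the new helper: serveIterator binds an ephemeral loopback port, writes the iterator's elements to whichever client connects, and returns the port number so the Python side can pull them lazily. The Scala sketch below is a minimal standalone illustration of that serving pattern only, not Spark's internal implementation; IteratorServerSketch and serveStrings are hypothetical names, and the real helper relies on Spark's own serializers and framing.

    import java.io.{BufferedOutputStream, DataOutputStream}
    import java.net.{InetAddress, ServerSocket}

    object IteratorServerSketch {
      // Serve the iterator's elements over a loopback socket from a daemon thread
      // and return the port the client should connect to.
      def serveStrings(items: Iterator[String], threadName: String): Int = {
        val server = new ServerSocket(0, 1, InetAddress.getByName("localhost"))
        new Thread(threadName) {
          setDaemon(true)
          override def run(): Unit = {
            val socket = server.accept()
            val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream))
            try {
              // Length-prefixed frames, terminated by -1, consumed lazily by the client.
              items.foreach { s =>
                val bytes = s.getBytes("UTF-8")
                out.writeInt(bytes.length)
                out.write(bytes)
              }
              out.writeInt(-1)
            } finally {
              out.close()
              socket.close()
              server.close()
            }
          }
        }.start()
        server.getLocalPort
      }
    }
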
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 37574cea0b..cd1f64e8aa 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -2299,14 +2299,14 @@ class RDD(object):
"""
Return an iterator that contains all of the elements in this RDD.
The iterator will consume as much memory as the largest partition in this RDD.
+
>>> rdd = sc.parallelize(range(10))
>>> [x for x in rdd.toLocalIterator()]
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
"""
- for partition in range(self.getNumPartitions()):
- rows = self.context.runJob(self, lambda x: x, [partition])
- for row in rows:
- yield row
+ with SCCallSiteSync(self.context) as css:
+ port = self.ctx._jvm.PythonRDD.toLocalIteratorAndServe(self._jrdd.rdd())
+ return _load_from_socket(port, self._jrdd_deserializer)
def _prepare_for_python_RDD(sc, command):
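
The Python method now delegates to the JVM-side RDD.toLocalIterator (served back over a socket) instead of issuing one runJob call per partition from Python. The memory bound in the docstring comes from the JVM side: one job is scheduled per partition, and only the partition currently being consumed is resident on the driver. A rough Scala sketch of that behaviour (the object name is made up for illustration):

    import org.apache.spark.{SparkConf, SparkContext}

    object ToLocalIteratorSemantics {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("sketch"))
        // Four partitions: toLocalIterator schedules one job per partition and only
        // the partition currently being consumed is held on the driver.
        val rdd = sc.parallelize(0 until 10, numSlices = 4)
        rdd.toLocalIterator.foreach(println)
        sc.stop()
      }
    }
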
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 7a69c4c70c..d473d6b534 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -241,6 +241,20 @@ class DataFrame(object):
return list(_load_from_socket(port, BatchedSerializer(PickleSerializer())))
@ignore_unicode_prefix
+ @since(2.0)
+ def toLocalIterator(self):
+ """
+ Returns an iterator that contains all of the rows in this :class:`DataFrame`.
+ The iterator will consume as much memory as the largest partition in this DataFrame.
+
+ >>> list(df.toLocalIterator())
+ [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
+ """
+ with SCCallSiteSync(self._sc) as css:
+ port = self._jdf.toPythonIterator()
+ return _load_from_socket(port, BatchedSerializer(PickleSerializer()))
+
+ @ignore_unicode_prefix
@since(1.3)
def limit(self, num):
"""Limits the result count to the number specified.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index a39a2113e5..8dfe8ff702 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql
import java.io.CharArrayWriter
import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
import scala.language.implicitConversions
import scala.reflect.runtime.universe.TypeTag
import scala.util.control.NonFatal
@@ -2057,6 +2058,24 @@ class Dataset[T] private[sql](
}
/**
+ * Return an iterator that contains all of the [[Row]]s in this [[Dataset]].
+ *
+ * The iterator will consume as much memory as the largest partition in this [[Dataset]].
+ *
+ * Note: this results in multiple Spark jobs, and if the input Dataset is the result
+ * of a wide transformation (e.g. a join with different partitioners), the input
+ * Dataset should be cached first to avoid recomputing it.
+ *
+ * @group action
+ * @since 2.0.0
+ */
+ def toLocalIterator(): java.util.Iterator[T] = withCallback("toLocalIterator", toDF()) { _ =>
+ withNewExecutionId {
+ queryExecution.executedPlan.executeToIterator().map(boundTEncoder.fromRow).asJava
+ }
+ }
+
+ /**
* Returns the number of rows in the [[Dataset]].
* @group action
* @since 1.6.0
@@ -2300,6 +2319,12 @@ class Dataset[T] private[sql](
}
}
+ protected[sql] def toPythonIterator(): Int = {
+ withNewExecutionId {
+ PythonRDD.toLocalIteratorAndServe(javaToPython.rdd)
+ }
+ }
+
////////////////////////////////////////////////////////////////////////////
// Private Helpers
////////////////////////////////////////////////////////////////////////////
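
A hedged sketch of the caching advice in the new scaladoc (names and data below are illustrative only): each partition pulled by toLocalIterator() is a separate job, so caching a wide result such as a join keeps the upstream query from being re-executed for every partition pulled.

    import scala.collection.JavaConverters._
    import org.apache.spark.sql.SparkSession

    object CacheBeforeLocalIterator {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[2]").appName("sketch").getOrCreate()
        import spark.implicits._

        val left  = Seq((1, "a"), (2, "b"), (3, "c")).toDF("id", "v1")
        val right = Seq((1, "x"), (2, "y"), (3, "z")).toDF("id", "v2")

        // Cache the joined result first, as the scaladoc advises, so the upstream
        // query is not re-executed for each per-partition job.
        val joined = left.join(right, "id").cache()

        joined.toLocalIterator().asScala.foreach(println)
        spark.stop()
      }
    }
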
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index ff19d1be1c..4091f65aec 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -249,20 +249,24 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
/**
* Decode the byte arrays back to UnsafeRows and put them into buffer.
*/
- private def decodeUnsafeRows(bytes: Array[Byte], buffer: ArrayBuffer[InternalRow]): Unit = {
+ private def decodeUnsafeRows(bytes: Array[Byte]): Iterator[InternalRow] = {
val nFields = schema.length
val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
val bis = new ByteArrayInputStream(bytes)
val ins = new DataInputStream(codec.compressedInputStream(bis))
- var sizeOfNextRow = ins.readInt()
- while (sizeOfNextRow >= 0) {
- val bs = new Array[Byte](sizeOfNextRow)
- ins.readFully(bs)
- val row = new UnsafeRow(nFields)
- row.pointTo(bs, sizeOfNextRow)
- buffer += row
- sizeOfNextRow = ins.readInt()
+
+ new Iterator[InternalRow] {
+ private var sizeOfNextRow = ins.readInt()
+ override def hasNext: Boolean = sizeOfNextRow >= 0
+ override def next(): InternalRow = {
+ val bs = new Array[Byte](sizeOfNextRow)
+ ins.readFully(bs)
+ val row = new UnsafeRow(nFields)
+ row.pointTo(bs, sizeOfNextRow)
+ sizeOfNextRow = ins.readInt()
+ row
+ }
}
}
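
The rewrite above turns an eager buffer fill into a lazy Iterator over the same length-prefixed framing. A self-contained Scala sketch of that decoding pattern, minus the compression codec and UnsafeRow handling of the real method (LazyFrameDecoder is a hypothetical name):

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

    object LazyFrameDecoder {
      // Decode a length-prefixed byte stream lazily; -1 marks the end of the stream.
      def decodeFrames(bytes: Array[Byte]): Iterator[Array[Byte]] = {
        val ins = new DataInputStream(new ByteArrayInputStream(bytes))
        new Iterator[Array[Byte]] {
          private var sizeOfNext = ins.readInt()
          override def hasNext: Boolean = sizeOfNext >= 0
          override def next(): Array[Byte] = {
            val frame = new Array[Byte](sizeOfNext)
            ins.readFully(frame)
            sizeOfNext = ins.readInt()
            frame
          }
        }
      }

      def main(args: Array[String]): Unit = {
        // Encode two frames followed by the -1 terminator, then decode them lazily.
        val bos = new ByteArrayOutputStream()
        val out = new DataOutputStream(bos)
        Seq("hello", "world").foreach { s =>
          val b = s.getBytes("UTF-8")
          out.writeInt(b.length)
          out.write(b)
        }
        out.writeInt(-1)
        decodeFrames(bos.toByteArray).foreach(f => println(new String(f, "UTF-8")))
      }
    }
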
@@ -274,12 +278,21 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
val results = ArrayBuffer[InternalRow]()
byteArrayRdd.collect().foreach { bytes =>
- decodeUnsafeRows(bytes, results)
+ decodeUnsafeRows(bytes).foreach(results.+=)
}
results.toArray
}
/**
+ * Runs this query returning the result as an iterator of InternalRow.
+ *
+ * Note: this will trigger multiple jobs (one for each partition).
+ */
+ def executeToIterator(): Iterator[InternalRow] = {
+ getByteArrayRdd().toLocalIterator.flatMap(decodeUnsafeRows)
+ }
+
+ /**
* Runs this query returning the result as an array, using external Row format.
*/
def executeCollectPublic(): Array[Row] = {
@@ -325,7 +338,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
(it: Iterator[Array[Byte]]) => if (it.hasNext) it.next() else Array.empty, p)
res.foreach { r =>
- decodeUnsafeRows(r.asInstanceOf[Array[Byte]], buf)
+ decodeUnsafeRows(r.asInstanceOf[Array[Byte]]).foreach(buf.+=)
}
partsScanned += p.size
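
executeToIterator composes RDD.toLocalIterator with the lazy decoder through flatMap, so only one partition's serialized batch sits on the driver while its rows are being consumed. A rough stand-in sketch of that composition (the string-based batch encoding below is invented purely for illustration):

    import org.apache.spark.{SparkConf, SparkContext}

    object LazyFlatMapOverLocalIterator {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("sketch"))
        // Stand-in for getByteArrayRdd(): one encoded batch per partition.
        val batches = sc.parallelize(Seq("1,2,3", "4,5", "6"), numSlices = 3)
        // Stand-in for decodeUnsafeRows: decode each batch into its values, lazily.
        val rows: Iterator[Int] =
          batches.toLocalIterator.flatMap(_.split(",").iterator.map(_.toInt))
        rows.foreach(println)
        sc.stop()
      }
    }
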
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index 873f681bdf..f26c57b301 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -86,6 +86,16 @@ public class JavaDatasetSuite implements Serializable {
}
@Test
+ public void testToLocalIterator() {
+ List<String> data = Arrays.asList("hello", "world");
+ Dataset<String> ds = context.createDataset(data, Encoders.STRING());
+ Iterator<String> iter = ds.toLocalIterator();
+ Assert.assertEquals("hello", iter.next());
+ Assert.assertEquals("world", iter.next());
+ Assert.assertFalse(iter.hasNext());
+ }
+
+ @Test
public void testCommonOperation() {
List<String> data = Arrays.asList("hello", "world");
Dataset<String> ds = context.createDataset(data, Encoders.STRING());
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 553bc436a6..2aa90568c3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -71,6 +71,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
assert(ds.first() == item)
assert(ds.take(1).head == item)
assert(ds.takeAsList(1).get(0) == item)
+ assert(ds.toLocalIterator().next() === item)
}
test("coalesce, repartition") {
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala
index a955314ba3..673a293ce2 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala
@@ -222,7 +222,7 @@ private[hive] class SparkExecuteStatementOperation(
val useIncrementalCollect =
hiveContext.getConf("spark.sql.thriftServer.incrementalCollect", "false").toBoolean
if (useIncrementalCollect) {
- result.rdd.toLocalIterator
+ result.toLocalIterator.asScala
} else {
result.collect().iterator
}
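
With the flag enabled, the Thrift server now streams result rows one partition at a time via toLocalIterator instead of collecting everything on the driver. A minimal sketch of the same toggle outside the server (the SparkSession setup is only to keep the sketch self-contained; in practice the property is set when launching the Thrift server, e.g. with --conf spark.sql.thriftServer.incrementalCollect=true):

    import scala.collection.JavaConverters._
    import org.apache.spark.sql.SparkSession

    object IncrementalCollectSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
          .master("local[2]")
          .appName("sketch")
          .config("spark.sql.thriftServer.incrementalCollect", "true")
          .getOrCreate()
        import spark.implicits._

        val result = Seq(1, 2, 3).toDF("id")
        val useIncrementalCollect =
          spark.conf.get("spark.sql.thriftServer.incrementalCollect", "false").toBoolean
        // Stream one partition at a time when enabled, otherwise collect eagerly.
        val rows =
          if (useIncrementalCollect) result.toLocalIterator().asScala
          else result.collect().iterator
        rows.foreach(println)
        spark.stop()
      }
    }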