8 files changed, 83 insertions, 16 deletions
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 6faa03c12b..4bca16a234 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -453,6 +453,10 @@ private[spark] object PythonRDD extends Logging {
     serveIterator(rdd.collect().iterator, s"serve RDD ${rdd.id}")
   }
 
+  def toLocalIteratorAndServe[T](rdd: RDD[T]): Int = {
+    serveIterator(rdd.toLocalIterator, s"serve toLocalIterator")
+  }
+
   def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int):
   JavaRDD[Array[Byte]] = {
     val file = new DataInputStream(new FileInputStream(filename))
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 37574cea0b..cd1f64e8aa 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -2299,14 +2299,14 @@ class RDD(object):
         """
         Return an iterator that contains all of the elements in this RDD.
         The iterator will consume as much memory as the largest partition in this RDD.
+
         >>> rdd = sc.parallelize(range(10))
         >>> [x for x in rdd.toLocalIterator()]
         [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
         """
-        for partition in range(self.getNumPartitions()):
-            rows = self.context.runJob(self, lambda x: x, [partition])
-            for row in rows:
-                yield row
+        with SCCallSiteSync(self.context) as css:
+            port = self.ctx._jvm.PythonRDD.toLocalIteratorAndServe(self._jrdd.rdd())
+        return _load_from_socket(port, self._jrdd_deserializer)
 
 
 def _prepare_for_python_RDD(sc, command):
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 7a69c4c70c..d473d6b534 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -241,6 +241,20 @@ class DataFrame(object):
         return list(_load_from_socket(port, BatchedSerializer(PickleSerializer())))
 
     @ignore_unicode_prefix
+    @since(2.0)
+    def toLocalIterator(self):
+        """
+        Returns an iterator that contains all of the rows in this :class:`DataFrame`.
+        The iterator will consume as much memory as the largest partition in this DataFrame.
+
+        >>> list(df.toLocalIterator())
+        [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
+        """
+        with SCCallSiteSync(self._sc) as css:
+            port = self._jdf.toPythonIterator()
+        return _load_from_socket(port, BatchedSerializer(PickleSerializer()))
+
+    @ignore_unicode_prefix
     @since(1.3)
     def limit(self, num):
         """Limits the result count to the number specified.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index a39a2113e5..8dfe8ff702 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql
 import java.io.CharArrayWriter
 
 import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
 import scala.language.implicitConversions
 import scala.reflect.runtime.universe.TypeTag
 import scala.util.control.NonFatal
@@ -2057,6 +2058,24 @@ class Dataset[T] private[sql](
   }
 
   /**
+   * Return an iterator that contains all of [[Row]]s in this [[Dataset]].
+   *
+   * The iterator will consume as much memory as the largest partition in this [[Dataset]].
+   *
+   * Note: this results in multiple Spark jobs, and if the input Dataset is the result
+   * of a wide transformation (e.g. join with different partitioners), to avoid
+   * recomputing the input Dataset should be cached first.
+   *
+   * @group action
+   * @since 2.0.0
+   */
+  def toLocalIterator(): java.util.Iterator[T] = withCallback("toLocalIterator", toDF()) { _ =>
+    withNewExecutionId {
+      queryExecution.executedPlan.executeToIterator().map(boundTEncoder.fromRow).asJava
+    }
+  }
+
+  /**
    * Returns the number of rows in the [[Dataset]].
    * @group action
    * @since 1.6.0
   */
@@ -2300,6 +2319,12 @@ class Dataset[T] private[sql](
     }
   }
 
+  protected[sql] def toPythonIterator(): Int = {
+    withNewExecutionId {
+      PythonRDD.toLocalIteratorAndServe(javaToPython.rdd)
+    }
+  }
+
   ////////////////////////////////////////////////////////////////////////////
   // Private Helpers
   ////////////////////////////////////////////////////////////////////////////
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index ff19d1be1c..4091f65aec 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -249,20 +249,24 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
   /**
    * Decode the byte arrays back to UnsafeRows and put them into buffer.
    */
-  private def decodeUnsafeRows(bytes: Array[Byte], buffer: ArrayBuffer[InternalRow]): Unit = {
+  private def decodeUnsafeRows(bytes: Array[Byte]): Iterator[InternalRow] = {
     val nFields = schema.length
 
     val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
     val bis = new ByteArrayInputStream(bytes)
     val ins = new DataInputStream(codec.compressedInputStream(bis))
-    var sizeOfNextRow = ins.readInt()
-    while (sizeOfNextRow >= 0) {
-      val bs = new Array[Byte](sizeOfNextRow)
-      ins.readFully(bs)
-      val row = new UnsafeRow(nFields)
-      row.pointTo(bs, sizeOfNextRow)
-      buffer += row
-      sizeOfNextRow = ins.readInt()
+
+    new Iterator[InternalRow] {
+      private var sizeOfNextRow = ins.readInt()
+      override def hasNext: Boolean = sizeOfNextRow >= 0
+      override def next(): InternalRow = {
+        val bs = new Array[Byte](sizeOfNextRow)
+        ins.readFully(bs)
+        val row = new UnsafeRow(nFields)
+        row.pointTo(bs, sizeOfNextRow)
+        sizeOfNextRow = ins.readInt()
+        row
+      }
     }
   }
 
@@ -274,12 +278,21 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
 
     val results = ArrayBuffer[InternalRow]()
     byteArrayRdd.collect().foreach { bytes =>
-      decodeUnsafeRows(bytes, results)
+      decodeUnsafeRows(bytes).foreach(results.+=)
     }
     results.toArray
   }
 
   /**
+   * Runs this query returning the result as an iterator of InternalRow.
+   *
+   * Note: this will trigger multiple jobs (one for each partition).
+   */
+  def executeToIterator(): Iterator[InternalRow] = {
+    getByteArrayRdd().toLocalIterator.flatMap(decodeUnsafeRows)
+  }
+
+  /**
    * Runs this query returning the result as an array, using external Row format.
   */
   def executeCollectPublic(): Array[Row] = {
@@ -325,7 +338,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
         (it: Iterator[Array[Byte]]) => if (it.hasNext) it.next() else Array.empty, p)
 
       res.foreach { r =>
-        decodeUnsafeRows(r.asInstanceOf[Array[Byte]], buf)
+        decodeUnsafeRows(r.asInstanceOf[Array[Byte]]).foreach(buf.+=)
       }
 
       partsScanned += p.size
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index 873f681bdf..f26c57b301 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -86,6 +86,16 @@ public class JavaDatasetSuite implements Serializable {
   }
 
   @Test
+  public void testToLocalIterator() {
+    List<String> data = Arrays.asList("hello", "world");
+    Dataset<String> ds = context.createDataset(data, Encoders.STRING());
+    Iterator<String> iter = ds.toLocalIterator();
+    Assert.assertEquals("hello", iter.next());
+    Assert.assertEquals("world", iter.next());
+    Assert.assertFalse(iter.hasNext());
+  }
+
+  @Test
   public void testCommonOperation() {
     List<String> data = Arrays.asList("hello", "world");
     Dataset<String> ds = context.createDataset(data, Encoders.STRING());
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 553bc436a6..2aa90568c3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -71,6 +71,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     assert(ds.first() == item)
     assert(ds.take(1).head == item)
     assert(ds.takeAsList(1).get(0) == item)
+    assert(ds.toLocalIterator().next() === item)
   }
 
   test("coalesce, repartition") {
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala
index a955314ba3..673a293ce2 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala
@@ -222,7 +222,7 @@ private[hive] class SparkExecuteStatementOperation(
       val useIncrementalCollect =
         hiveContext.getConf("spark.sql.thriftServer.incrementalCollect", "false").toBoolean
       if (useIncrementalCollect) {
-        result.rdd.toLocalIterator
+        result.toLocalIterator.asScala
       } else {
        result.collect().iterator
      }
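
For illustration only, the sketch below shows how a caller might drive the new Dataset.toLocalIterator() from the Scala side, following the note in the API doc about caching the input to avoid recomputation. It is not part of the change; the SQLContext wiring, the printAllRows helper, and the sample data are assumptions.

// A minimal, hypothetical usage sketch (not part of this patch). Assumes a Spark 2.0
// SQLContext is available; the dataset contents are illustrative.
import scala.collection.JavaConverters._

import org.apache.spark.sql.SQLContext

def printAllRows(sqlContext: SQLContext): Unit = {
  import sqlContext.implicits._

  val ds = sqlContext.createDataset(Seq("hello", "world"))

  // toLocalIterator() triggers one job per partition, so cache the input first to
  // avoid recomputing it for every job (see the Dataset doc comment above).
  ds.cache()

  // Unlike collect(), only one partition's worth of rows needs to fit in driver
  // memory at a time. The API returns a java.util.Iterator[T]; asScala adapts it
  // for Scala callers.
  ds.toLocalIterator().asScala.foreach(println)

  ds.unpersist()
}

The PySpark side is analogous: the df.toLocalIterator() added in the dataframe.py hunk streams batches to the driver over the same local socket mechanism that collect() already uses.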