aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDavies Liu <davies@databricks.com>2016-04-04 13:31:44 -0700
committerDavies Liu <davies.liu@gmail.com>2016-04-04 13:31:44 -0700
commitcc70f174169f45c85d459126a68bbe43c0bec328 (patch)
treea455d9fab09fb11f6d78265243a1a301a6e5563b
parent7143904700435265975d36f073cce2833467e121 (diff)
downloadspark-cc70f174169f45c85d459126a68bbe43c0bec328.tar.gz
spark-cc70f174169f45c85d459126a68bbe43c0bec328.tar.bz2
spark-cc70f174169f45c85d459126a68bbe43c0bec328.zip
[SPARK-14334] [SQL] add toLocalIterator for Dataset/DataFrame
## What changes were proposed in this pull request? RDD.toLocalIterator() could be used to fetch one partition at a time to reduce the memory usage. Right now, for Dataset/Dataframe we have to use df.rdd.toLocalIterator, which is super slow also requires lots of memory (because of the Java serializer or even Kyro serializer). This PR introduce an optimized toLocalIterator for Dataset/DataFrame, which is much faster and requires much less memory. For a partition with 5 millions rows, `df.rdd.toIterator` took about 100 seconds, but df.toIterator took less than 7 seconds. For 10 millions row, rdd.toIterator will crash (not enough memory) with 4G heap, but df.toLocalIterator could finished in 12 seconds. The JDBC server has been updated to use DataFrame.toIterator. ## How was this patch tested? Existing tests. Author: Davies Liu <davies@databricks.com> Closes #12114 from davies/local_iterator.
-rw-r--r--core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala4
-rw-r--r--python/pyspark/rdd.py8
-rw-r--r--python/pyspark/sql/dataframe.py14
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala25
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala35
-rw-r--r--sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java10
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala1
-rw-r--r--sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala2
8 files changed, 83 insertions, 16 deletions
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 6faa03c12b..4bca16a234 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -453,6 +453,10 @@ private[spark] object PythonRDD extends Logging {
serveIterator(rdd.collect().iterator, s"serve RDD ${rdd.id}")
}
+ def toLocalIteratorAndServe[T](rdd: RDD[T]): Int = {
+ serveIterator(rdd.toLocalIterator, s"serve toLocalIterator")
+ }
+
def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int):
JavaRDD[Array[Byte]] = {
val file = new DataInputStream(new FileInputStream(filename))
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 37574cea0b..cd1f64e8aa 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -2299,14 +2299,14 @@ class RDD(object):
"""
Return an iterator that contains all of the elements in this RDD.
The iterator will consume as much memory as the largest partition in this RDD.
+
>>> rdd = sc.parallelize(range(10))
>>> [x for x in rdd.toLocalIterator()]
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
"""
- for partition in range(self.getNumPartitions()):
- rows = self.context.runJob(self, lambda x: x, [partition])
- for row in rows:
- yield row
+ with SCCallSiteSync(self.context) as css:
+ port = self.ctx._jvm.PythonRDD.toLocalIteratorAndServe(self._jrdd.rdd())
+ return _load_from_socket(port, self._jrdd_deserializer)
def _prepare_for_python_RDD(sc, command):
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 7a69c4c70c..d473d6b534 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -241,6 +241,20 @@ class DataFrame(object):
return list(_load_from_socket(port, BatchedSerializer(PickleSerializer())))
@ignore_unicode_prefix
+ @since(2.0)
+ def toLocalIterator(self):
+ """
+ Returns an iterator that contains all of the rows in this :class:`DataFrame`.
+ The iterator will consume as much memory as the largest partition in this DataFrame.
+
+ >>> list(df.toLocalIterator())
+ [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
+ """
+ with SCCallSiteSync(self._sc) as css:
+ port = self._jdf.toPythonIterator()
+ return _load_from_socket(port, BatchedSerializer(PickleSerializer()))
+
+ @ignore_unicode_prefix
@since(1.3)
def limit(self, num):
"""Limits the result count to the number specified.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index a39a2113e5..8dfe8ff702 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql
import java.io.CharArrayWriter
import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
import scala.language.implicitConversions
import scala.reflect.runtime.universe.TypeTag
import scala.util.control.NonFatal
@@ -2057,6 +2058,24 @@ class Dataset[T] private[sql](
}
/**
+ * Return an iterator that contains all of [[Row]]s in this [[Dataset]].
+ *
+ * The iterator will consume as much memory as the largest partition in this [[Dataset]].
+ *
+ * Note: this results in multiple Spark jobs, and if the input Dataset is the result
+ * of a wide transformation (e.g. join with different partitioners), to avoid
+ * recomputing the input Dataset should be cached first.
+ *
+ * @group action
+ * @since 2.0.0
+ */
+ def toLocalIterator(): java.util.Iterator[T] = withCallback("toLocalIterator", toDF()) { _ =>
+ withNewExecutionId {
+ queryExecution.executedPlan.executeToIterator().map(boundTEncoder.fromRow).asJava
+ }
+ }
+
+ /**
* Returns the number of rows in the [[Dataset]].
* @group action
* @since 1.6.0
@@ -2300,6 +2319,12 @@ class Dataset[T] private[sql](
}
}
+ protected[sql] def toPythonIterator(): Int = {
+ withNewExecutionId {
+ PythonRDD.toLocalIteratorAndServe(javaToPython.rdd)
+ }
+ }
+
////////////////////////////////////////////////////////////////////////////
// Private Helpers
////////////////////////////////////////////////////////////////////////////
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index ff19d1be1c..4091f65aec 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -249,20 +249,24 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
/**
* Decode the byte arrays back to UnsafeRows and put them into buffer.
*/
- private def decodeUnsafeRows(bytes: Array[Byte], buffer: ArrayBuffer[InternalRow]): Unit = {
+ private def decodeUnsafeRows(bytes: Array[Byte]): Iterator[InternalRow] = {
val nFields = schema.length
val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
val bis = new ByteArrayInputStream(bytes)
val ins = new DataInputStream(codec.compressedInputStream(bis))
- var sizeOfNextRow = ins.readInt()
- while (sizeOfNextRow >= 0) {
- val bs = new Array[Byte](sizeOfNextRow)
- ins.readFully(bs)
- val row = new UnsafeRow(nFields)
- row.pointTo(bs, sizeOfNextRow)
- buffer += row
- sizeOfNextRow = ins.readInt()
+
+ new Iterator[InternalRow] {
+ private var sizeOfNextRow = ins.readInt()
+ override def hasNext: Boolean = sizeOfNextRow >= 0
+ override def next(): InternalRow = {
+ val bs = new Array[Byte](sizeOfNextRow)
+ ins.readFully(bs)
+ val row = new UnsafeRow(nFields)
+ row.pointTo(bs, sizeOfNextRow)
+ sizeOfNextRow = ins.readInt()
+ row
+ }
}
}
@@ -274,12 +278,21 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
val results = ArrayBuffer[InternalRow]()
byteArrayRdd.collect().foreach { bytes =>
- decodeUnsafeRows(bytes, results)
+ decodeUnsafeRows(bytes).foreach(results.+=)
}
results.toArray
}
/**
+ * Runs this query returning the result as an iterator of InternalRow.
+ *
+ * Note: this will trigger multiple jobs (one for each partition).
+ */
+ def executeToIterator(): Iterator[InternalRow] = {
+ getByteArrayRdd().toLocalIterator.flatMap(decodeUnsafeRows)
+ }
+
+ /**
* Runs this query returning the result as an array, using external Row format.
*/
def executeCollectPublic(): Array[Row] = {
@@ -325,7 +338,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
(it: Iterator[Array[Byte]]) => if (it.hasNext) it.next() else Array.empty, p)
res.foreach { r =>
- decodeUnsafeRows(r.asInstanceOf[Array[Byte]], buf)
+ decodeUnsafeRows(r.asInstanceOf[Array[Byte]]).foreach(buf.+=)
}
partsScanned += p.size
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index 873f681bdf..f26c57b301 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -86,6 +86,16 @@ public class JavaDatasetSuite implements Serializable {
}
@Test
+ public void testToLocalIterator() {
+ List<String> data = Arrays.asList("hello", "world");
+ Dataset<String> ds = context.createDataset(data, Encoders.STRING());
+ Iterator<String> iter = ds.toLocalIterator();
+ Assert.assertEquals("hello", iter.next());
+ Assert.assertEquals("world", iter.next());
+ Assert.assertFalse(iter.hasNext());
+ }
+
+ @Test
public void testCommonOperation() {
List<String> data = Arrays.asList("hello", "world");
Dataset<String> ds = context.createDataset(data, Encoders.STRING());
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 553bc436a6..2aa90568c3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -71,6 +71,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
assert(ds.first() == item)
assert(ds.take(1).head == item)
assert(ds.takeAsList(1).get(0) == item)
+ assert(ds.toLocalIterator().next() === item)
}
test("coalesce, repartition") {
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala
index a955314ba3..673a293ce2 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala
@@ -222,7 +222,7 @@ private[hive] class SparkExecuteStatementOperation(
val useIncrementalCollect =
hiveContext.getConf("spark.sql.thriftServer.incrementalCollect", "false").toBoolean
if (useIncrementalCollect) {
- result.rdd.toLocalIterator
+ result.toLocalIterator.asScala
} else {
result.collect().iterator
}