aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorReynold Xin <rxin@databricks.com>2015-11-17 21:40:58 -0800
committerReynold Xin <rxin@databricks.com>2015-11-17 21:40:58 -0800
commit91f4b6f2db12650dfc33a576803ba8aeccf935dd (patch)
tree6767fc52ef951f8930c9cf9c74b8b7e3f1a7f26e /sql
parent98be8169f07eb0f1b8f01776c71d0e1ed3d5e4d5 (diff)
downloadspark-91f4b6f2db12650dfc33a576803ba8aeccf935dd.tar.gz
spark-91f4b6f2db12650dfc33a576803ba8aeccf935dd.tar.bz2
spark-91f4b6f2db12650dfc33a576803ba8aeccf935dd.zip
[SPARK-11797][SQL] collect, first, and take should use encoders for serialization
They were previously using Spark's default serializer for serialization. Author: Reynold Xin <rxin@databricks.com> Closes #9787 from rxin/SPARK-11797.
Diffstat (limited to 'sql')
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala17
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala30
2 files changed, 41 insertions, 6 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index bd01dd4dc5..718ed812dd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -22,6 +22,7 @@ import scala.collection.JavaConverters._
import org.apache.spark.annotation.Experimental
import org.apache.spark.rdd.RDD
import org.apache.spark.api.java.function._
+import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions._
@@ -199,7 +200,6 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
def mapPartitions[U : Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] = {
- encoderFor[T].assertUnresolved()
new Dataset[U](
sqlContext,
MapPartitions[T, U](
@@ -519,7 +519,7 @@ class Dataset[T] private[sql](
* Returns the first element in this [[Dataset]].
* @since 1.6.0
*/
- def first(): T = rdd.first()
+ def first(): T = take(1).head
/**
* Returns an array that contains all the elements in this [[Dataset]].
@@ -530,7 +530,14 @@ class Dataset[T] private[sql](
* For Java API, use [[collectAsList]].
* @since 1.6.0
*/
- def collect(): Array[T] = rdd.collect()
+ def collect(): Array[T] = {
+ // This is different from Dataset.rdd in that it collects Rows, and then runs the encoders
+ // to convert the rows into objects of type T.
+ val tEnc = resolvedTEncoder
+ val input = queryExecution.analyzed.output
+ val bound = tEnc.bind(input)
+ queryExecution.toRdd.map(_.copy()).collect().map(bound.fromRow)
+ }
/**
* Returns an array that contains all the elements in this [[Dataset]].
@@ -541,7 +548,7 @@ class Dataset[T] private[sql](
* For Java API, use [[collectAsList]].
* @since 1.6.0
*/
- def collectAsList(): java.util.List[T] = rdd.collect().toSeq.asJava
+ def collectAsList(): java.util.List[T] = collect().toSeq.asJava
/**
* Returns the first `num` elements of this [[Dataset]] as an array.
@@ -551,7 +558,7 @@ class Dataset[T] private[sql](
*
* @since 1.6.0
*/
- def take(num: Int): Array[T] = rdd.take(num)
+ def take(num: Int): Array[T] = withPlan(Limit(Literal(num), _)).collect()
/**
* Returns the first `num` elements of this [[Dataset]] as an array.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index a3922340cc..ea29428c55 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql
+import java.io.{ObjectInput, ObjectOutput, Externalizable}
+
import scala.language.postfixOps
import org.apache.spark.sql.functions._
@@ -24,6 +26,20 @@ import org.apache.spark.sql.test.SharedSQLContext
case class ClassData(a: String, b: Int)
+/**
+ * A class used to test serialization using encoders. This class throws exceptions when using
+ * Java serialization -- so the only way it can be "serialized" is through our encoders.
+ */
+case class NonSerializableCaseClass(value: String) extends Externalizable {
+ override def readExternal(in: ObjectInput): Unit = {
+ throw new UnsupportedOperationException
+ }
+
+ override def writeExternal(out: ObjectOutput): Unit = {
+ throw new UnsupportedOperationException
+ }
+}
+
class DatasetSuite extends QueryTest with SharedSQLContext {
import testImplicits._
@@ -41,6 +57,16 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
1, 1, 1)
}
+ test("collect, first, and take should use encoders for serialization") {
+ val item = NonSerializableCaseClass("abcd")
+ val ds = Seq(item).toDS()
+ assert(ds.collect().head == item)
+ assert(ds.collectAsList().get(0) == item)
+ assert(ds.first() == item)
+ assert(ds.take(1).head == item)
+ assert(ds.takeAsList(1).get(0) == item)
+ }
+
test("as tuple") {
val data = Seq(("a", 1), ("b", 2)).toDF("a", "b")
checkAnswer(
@@ -75,6 +101,8 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
ignore("Dataset should set the resolved encoders internally for maps") {
// TODO: Enable this once we fix SPARK-11793.
+ // We inject a group by here to make sure this test case is future proof
+ // when we implement better pipelining and local execution mode.
val ds: Dataset[(ClassData, Long)] = Seq(ClassData("one", 1), ClassData("two", 2)).toDS()
.map(c => ClassData(c.a, c.b + 1))
.groupBy(p => p).count()
@@ -219,7 +247,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
("a", 30), ("b", 3), ("c", 1))
}
- test("groupBy function, fatMap") {
+ test("groupBy function, flatMap") {
val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
val grouped = ds.groupBy(v => (v._1, "word"))
val agged = grouped.flatMap { case (g, iter) => Iterator(g._1, iter.map(_._2).sum.toString) }