author     Wenchen Fan <wenchen@databricks.com>           2016-06-13 09:58:48 -0700
committer  Herman van Hovell <hvanhovell@databricks.com>  2016-06-13 09:58:48 -0700
commit     cd47e233749f42b016264569a214cbf67f45f436 (patch)
tree       96fdef37e4746dc7e06535f9822538458536134a
parent     d681742b2d37bd68cf5d8d3161e0f48846f6f9d4 (diff)
[SPARK-15814][SQL] Aggregator can return null result
## What changes were proposed in this pull request?

This is similar to the bug fixed in https://github.com/apache/spark/pull/13425: we should account for a null result object and wrap the `CreateStruct` with `If` to perform the null check. This PR also improves the test framework to check the objects of a `Dataset[T]` directly, instead of calling `toDF` and comparing rows.

## How was this patch tested?

New test in `DatasetAggregatorSuite`.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #13553 from cloud-fan/agg-null.
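For context, here is a minimal standalone sketch (not part of this commit) of the user-facing behavior the new test exercises: an `Aggregator` whose `finish` returns null for some groups, invoked through `groupByKey(...).agg(...)`. The `NullableAgg` and `Demo` names and the local `SparkSession` setup are illustrative only; the authoritative version is the commit's own test, "SPARK-15814 Aggregator can return null result" in `DatasetAggregatorSuite`.

```scala
import org.apache.spark.sql.{Encoder, Encoders, SparkSession}
import org.apache.spark.sql.expressions.Aggregator

case class AggData(a: Int, b: String)

// Illustrative aggregator whose finish() may return null -- the case this commit fixes.
object NullableAgg extends Aggregator[AggData, AggData, AggData] {
  def zero: AggData = AggData(0, "")
  def reduce(buf: AggData, in: AggData): AggData = AggData(buf.a + in.a, buf.b + in.b)
  def merge(b1: AggData, b2: AggData): AggData = AggData(b1.a + b2.a, b1.b + b2.b)
  def finish(r: AggData): AggData = if (r.a % 2 == 0) null else r // null for even sums
  def bufferEncoder: Encoder[AggData] = Encoders.product[AggData]
  def outputEncoder: Encoder[AggData] = Encoders.product[AggData]
}

object Demo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("null-agg").getOrCreate()
    import spark.implicits._

    val ds = Seq(AggData(1, "one"), AggData(2, "two")).toDS()
    // With the fix, a group whose result object is null is serialized as a null
    // struct instead of failing or producing a corrupt row.
    ds.groupByKey(_.a).agg(NullableAgg.toColumn).show()

    spark.stop()
  }
}
```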
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala |  7
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala                       | 23
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala                        |  6
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala                                 | 38
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala                                    | 95
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala         |  4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala                |  4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySinkSuite.scala                    |  4
8 files changed, 117 insertions(+), 64 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
index ecb56e2a28..8bdfa48a30 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
@@ -127,7 +127,12 @@ case class TypedAggregateExpression(
dataType match {
case s: StructType =>
- ReferenceToExpressions(CreateStruct(outputSerializer), resultObj :: Nil)
+ val objRef = outputSerializer.head.find(_.isInstanceOf[BoundReference]).get
+ val struct = If(
+ IsNull(objRef),
+ Literal.create(null, dataType),
+ CreateStruct(outputSerializer))
+ ReferenceToExpressions(struct, resultObj :: Nil)
case _ =>
assert(outputSerializer.length == 1)
outputSerializer.head transform {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
index f9b4cd83c3..f955120dc5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
@@ -115,11 +115,23 @@ object RowAgg extends Aggregator[Row, Int, Int] {
override def outputEncoder: Encoder[Int] = Encoders.scalaInt
}
+object NullResultAgg extends Aggregator[AggData, AggData, AggData] {
+ override def zero: AggData = AggData(0, "")
+ override def reduce(b: AggData, a: AggData): AggData = AggData(b.a + a.a, b.b + a.b)
+ override def finish(reduction: AggData): AggData = {
+ if (reduction.a % 2 == 0) null else reduction
+ }
+ override def merge(b1: AggData, b2: AggData): AggData = AggData(b1.a + b2.a, b1.b + b2.b)
+ override def bufferEncoder: Encoder[AggData] = Encoders.product[AggData]
+ override def outputEncoder: Encoder[AggData] = Encoders.product[AggData]
+}
-class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
+class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
import testImplicits._
+ private implicit val ordering = Ordering.by((c: AggData) => c.a -> c.b)
+
test("typed aggregation: TypedAggregator") {
val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
@@ -204,7 +216,7 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
ds.select(expr("avg(a)").as[Double], ComplexBufferAgg.toColumn),
(1.5, 2))
- checkDataset(
+ checkDatasetUnorderly(
ds.groupByKey(_.b).agg(ComplexBufferAgg.toColumn),
("one", 1), ("two", 1))
}
@@ -271,4 +283,11 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
"RowAgg(org.apache.spark.sql.Row)")
assert(df.groupBy($"j").agg(RowAgg.toColumn as "agg1").columns.last == "agg1")
}
+
+ test("SPARK-15814 Aggregator can return null result") {
+ val ds = Seq(AggData(1, "one"), AggData(2, "two")).toDS()
+ checkDatasetUnorderly(
+ ds.groupByKey(_.a).agg(NullResultAgg.toColumn),
+ 1 -> AggData(1, "one"), 2 -> null)
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
index a634502e2e..6aa3d3fe80 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
@@ -82,7 +82,7 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
test("groupBy function, keys") {
val ds = Seq(1, 2, 3, 4, 5).toDS()
val grouped = ds.groupByKey(_ % 2)
- checkDataset(
+ checkDatasetUnorderly(
grouped.keys,
0, 1)
}
@@ -95,7 +95,7 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
(name, iter.size)
}
- checkDataset(
+ checkDatasetUnorderly(
agged,
("even", 5), ("odd", 6))
}
@@ -105,7 +105,7 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
val grouped = ds.groupByKey(_.length)
val agged = grouped.flatMapGroups { case (g, iter) => Iterator(g.toString, iter.mkString) }
- checkDataset(
+ checkDatasetUnorderly(
agged,
"1", "abc", "3", "xyz", "5", "hello")
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 4536a7356f..96d85f12e8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -32,6 +32,8 @@ import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructT
class DatasetSuite extends QueryTest with SharedSQLContext {
import testImplicits._
+ private implicit val ordering = Ordering.by((c: ClassData) => c.a -> c.b)
+
test("toDS") {
val data = Seq(("a", 1), ("b", 2), ("c", 3))
checkDataset(
@@ -95,12 +97,12 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
}
assert(ds.repartition(10).rdd.partitions.length == 10)
- checkDataset(
+ checkDatasetUnorderly(
ds.repartition(10),
data: _*)
assert(ds.coalesce(1).rdd.partitions.length == 1)
- checkDataset(
+ checkDatasetUnorderly(
ds.coalesce(1),
data: _*)
}
@@ -163,7 +165,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
.map(c => ClassData(c.a, c.b + 1))
.groupByKey(p => p).count()
- checkDataset(
+ checkDatasetUnorderly(
ds,
(ClassData("one", 2), 1L), (ClassData("two", 3), 1L))
}
@@ -204,7 +206,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
test("select 2, primitive and class, fields reordered") {
val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS()
- checkDecoding(
+ checkDataset(
ds.select(
expr("_1").as[String],
expr("named_struct('b', _2, 'a', _1)").as[ClassData]),
@@ -291,7 +293,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
test("groupBy function, keys") {
val ds = Seq(("a", 1), ("b", 1)).toDS()
val grouped = ds.groupByKey(v => (1, v._2))
- checkDataset(
+ checkDatasetUnorderly(
grouped.keys,
(1, 1))
}
@@ -301,7 +303,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
val grouped = ds.groupByKey(v => (v._1, "word"))
val agged = grouped.mapGroups { case (g, iter) => (g._1, iter.map(_._2).sum) }
- checkDataset(
+ checkDatasetUnorderly(
agged,
("a", 30), ("b", 3), ("c", 1))
}
@@ -313,7 +315,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
Iterator(g._1, iter.map(_._2).sum.toString)
}
- checkDataset(
+ checkDatasetUnorderly(
agged,
"a", "30", "b", "3", "c", "1")
}
@@ -322,7 +324,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
val ds = Seq("abc", "xyz", "hello").toDS()
val agged = ds.groupByKey(_.length).reduceGroups(_ + _)
- checkDataset(
+ checkDatasetUnorderly(
agged,
3 -> "abcxyz", 5 -> "hello")
}
@@ -340,7 +342,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
test("typed aggregation: expr") {
val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
- checkDataset(
+ checkDatasetUnorderly(
ds.groupByKey(_._1).agg(sum("_2").as[Long]),
("a", 30L), ("b", 3L), ("c", 1L))
}
@@ -348,7 +350,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
test("typed aggregation: expr, expr") {
val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
- checkDataset(
+ checkDatasetUnorderly(
ds.groupByKey(_._1).agg(sum("_2").as[Long], sum($"_2" + 1).as[Long]),
("a", 30L, 32L), ("b", 3L, 5L), ("c", 1L, 2L))
}
@@ -356,7 +358,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
test("typed aggregation: expr, expr, expr") {
val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
- checkDataset(
+ checkDatasetUnorderly(
ds.groupByKey(_._1).agg(sum("_2").as[Long], sum($"_2" + 1).as[Long], count("*")),
("a", 30L, 32L, 2L), ("b", 3L, 5L, 2L), ("c", 1L, 2L, 1L))
}
@@ -364,7 +366,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
test("typed aggregation: expr, expr, expr, expr") {
val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
- checkDataset(
+ checkDatasetUnorderly(
ds.groupByKey(_._1).agg(
sum("_2").as[Long],
sum($"_2" + 1).as[Long],
@@ -380,7 +382,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
Iterator(key -> (data1.map(_._2).mkString + "#" + data2.map(_._2).mkString))
}
- checkDataset(
+ checkDatasetUnorderly(
cogrouped,
1 -> "a#", 2 -> "#q", 3 -> "abcfoo#w", 5 -> "hello#er")
}
@@ -392,7 +394,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
Iterator(key -> (data1.map(_._2.a).mkString + data2.map(_._2.a).mkString))
}
- checkDataset(
+ checkDatasetUnorderly(
cogrouped,
1 -> "a", 2 -> "bc", 3 -> "d")
}
@@ -482,8 +484,8 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
checkDataset(
ds1.joinWith(ds2, lit(true)),
((nullInt, "1"), (nullInt, "1")),
- ((new java.lang.Integer(22), "2"), (nullInt, "1")),
((nullInt, "1"), (new java.lang.Integer(22), "2")),
+ ((new java.lang.Integer(22), "2"), (nullInt, "1")),
((new java.lang.Integer(22), "2"), (new java.lang.Integer(22), "2")))
}
@@ -776,9 +778,9 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
val ds1 = ds.as("d1")
val ds2 = ds.as("d2")
- checkDataset(ds1.joinWith(ds2, $"d1.value" === $"d2.value"), (2, 2), (3, 3), (4, 4))
- checkDataset(ds1.intersect(ds2), 2, 3, 4)
- checkDataset(ds1.except(ds1))
+ checkDatasetUnorderly(ds1.joinWith(ds2, $"d1.value" === $"d2.value"), (2, 2), (3, 3), (4, 4))
+ checkDatasetUnorderly(ds1.intersect(ds2), 2, 3, 4)
+ checkDatasetUnorderly(ds1.except(ds1))
}
test("SPARK-15441: Dataset outer join") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index acb59d46e1..742f036e55 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -68,28 +68,62 @@ abstract class QueryTest extends PlanTest {
/**
* Evaluates a dataset to make sure that the result of calling collect matches the given
* expected answer.
- * - Special handling is done based on whether the query plan should be expected to return
- * the results in sorted order.
- * - This function also checks to make sure that the schema for serializing the expected answer
- * matches that produced by the dataset (i.e. does manual construction of object match
- * the constructed encoder for cases like joins, etc). Note that this means that it will fail
- * for cases where reordering is done on fields. For such tests, user `checkDecoding` instead
- * which performs a subset of the checks done by this function.
*/
protected def checkDataset[T](
- ds: Dataset[T],
+ ds: => Dataset[T],
expectedAnswer: T*): Unit = {
- checkAnswer(
- ds.toDF(),
- spark.createDataset(expectedAnswer)(ds.exprEnc).toDF().collect().toSeq)
+ val result = getResult(ds)
- checkDecoding(ds, expectedAnswer: _*)
+ if (!compare(result.toSeq, expectedAnswer)) {
+ fail(
+ s"""
+ |Decoded objects do not match expected objects:
+ |expected: $expectedAnswer
+ |actual: ${result.toSeq}
+ |${ds.exprEnc.deserializer.treeString}
+ """.stripMargin)
+ }
}
- protected def checkDecoding[T](
+ /**
+ * Evaluates a dataset to make sure that the result of calling collect matches the given
+ * expected answer, after sort.
+ */
+ protected def checkDatasetUnorderly[T : Ordering](
ds: => Dataset[T],
expectedAnswer: T*): Unit = {
- val decoded = try ds.collect().toSet catch {
+ val result = getResult(ds)
+
+ if (!compare(result.toSeq.sorted, expectedAnswer.sorted)) {
+ fail(
+ s"""
+ |Decoded objects do not match expected objects:
+ |expected: $expectedAnswer
+ |actual: ${result.toSeq}
+ |${ds.exprEnc.deserializer.treeString}
+ """.stripMargin)
+ }
+ }
+
+ private def getResult[T](ds: => Dataset[T]): Array[T] = {
+ val analyzedDS = try ds catch {
+ case ae: AnalysisException =>
+ if (ae.plan.isDefined) {
+ fail(
+ s"""
+ |Failed to analyze query: $ae
+ |${ae.plan.get}
+ |
+ |${stackTraceToString(ae)}
+ """.stripMargin)
+ } else {
+ throw ae
+ }
+ }
+ checkJsonFormat(analyzedDS)
+ assertEmptyMissingInput(analyzedDS)
+
+ try ds.collect() catch {
case e: Exception =>
fail(
s"""
@@ -99,24 +133,17 @@ abstract class QueryTest extends PlanTest {
|${ds.queryExecution}
""".stripMargin, e)
}
+ }
- // Handle the case where the return type is an array
- val isArray = decoded.headOption.map(_.getClass.isArray).getOrElse(false)
- def normalEquality = decoded == expectedAnswer.toSet
- def expectedAsSeq = expectedAnswer.map(_.asInstanceOf[Array[_]].toSeq).toSet
- def decodedAsSeq = decoded.map(_.asInstanceOf[Array[_]].toSeq)
-
- if (!((isArray && expectedAsSeq == decodedAsSeq) || normalEquality)) {
- val expected = expectedAnswer.toSet.toSeq.map((a: Any) => a.toString).sorted
- val actual = decoded.toSet.toSeq.map((a: Any) => a.toString).sorted
-
- val comparison = sideBySide("expected" +: expected, "spark" +: actual).mkString("\n")
- fail(
- s"""Decoded objects do not match expected objects:
- |$comparison
- |${ds.exprEnc.deserializer.treeString}
- """.stripMargin)
- }
+ private def compare(obj1: Any, obj2: Any): Boolean = (obj1, obj2) match {
+ case (null, null) => true
+ case (null, _) => false
+ case (_, null) => false
+ case (a: Array[_], b: Array[_]) =>
+ a.length == b.length && a.zip(b).forall { case (l, r) => compare(l, r)}
+ case (a: Iterable[_], b: Iterable[_]) =>
+ a.size == b.size && a.zip(b).forall { case (l, r) => compare(l, r)}
+ case (a, b) => a == b
}
/**
@@ -143,7 +170,7 @@ abstract class QueryTest extends PlanTest {
checkJsonFormat(analyzedDF)
- assertEmptyMissingInput(df)
+ assertEmptyMissingInput(analyzedDF)
QueryTest.checkAnswer(analyzedDF, expectedAnswer) match {
case Some(errorMessage) => fail(errorMessage)
@@ -201,10 +228,10 @@ abstract class QueryTest extends PlanTest {
planWithCaching)
}
- private def checkJsonFormat(df: DataFrame): Unit = {
+ private def checkJsonFormat(ds: Dataset[_]): Unit = {
// Get the analyzed plan and rewrite the PredicateSubqueries in order to make sure that
// RDD and Data resolution does not break.
- val logicalPlan = df.queryExecution.analyzed
+ val logicalPlan = ds.queryExecution.analyzed
// bypass some cases that we can't handle currently.
logicalPlan.transform {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala
index 4ed517cb26..71d3da9158 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala
@@ -132,9 +132,9 @@ class TextSuite extends QueryTest with SharedSQLContext {
ds1.write.text(s"$path/part=a")
ds1.write.text(s"$path/part=b")
- checkDataset(
+ checkAnswer(
spark.read.format("text").load(path).select($"part"),
- Row("a"), Row("b"))
+ Row("a") :: Row("b") :: Nil)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala
index 1c73208736..bb3063dc34 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala
@@ -140,7 +140,7 @@ class FileStreamSinkSuite extends StreamTest {
}
val outputDf = spark.read.parquet(outputDir).as[Int]
- checkDataset(outputDf, 1, 2, 3)
+ checkDatasetUnorderly(outputDf, 1, 2, 3)
} finally {
if (query != null) {
@@ -191,7 +191,7 @@ class FileStreamSinkSuite extends StreamTest {
assert(hadoopdFsRelations.head.dataSchema.exists(_.name == "value"))
// Verify the data is correctly read
- checkDataset(
+ checkDatasetUnorderly(
outputDf.as[(Int, Int)],
(1000, 1), (2000, 2), (3000, 3))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySinkSuite.scala
index df76499fa2..9aada0b18d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySinkSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySinkSuite.scala
@@ -174,13 +174,13 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter {
input.addData(1, 2, 3)
query.processAllAvailable()
- checkDataset(
+ checkDatasetUnorderly(
spark.table("memStream").as[(Int, Long)],
(1, 1L), (2, 1L), (3, 1L))
input.addData(4, 5, 6)
query.processAllAvailable()
- checkDataset(
+ checkDatasetUnorderly(
spark.table("memStream").as[(Int, Long)],
(1, 1L), (2, 1L), (3, 1L), (4, 1L), (5, 1L), (6, 1L))