aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTakeshi YAMAMURO <linguin.m.s@gmail.com>2016-06-24 21:07:03 -0700
committerHerman van Hovell <hvanhovell@databricks.com>2016-06-24 21:07:03 -0700
commitd2e44d7db82ff3c3326af7bf7ea69c803803698e (patch)
treea2ad958646e3532e7e3f101aa0a4517a4386c3e3
parent9053054c7f5ec2b9e3d8efbe6bfbfa68a6d1f0d0 (diff)
downloadspark-d2e44d7db82ff3c3326af7bf7ea69c803803698e.tar.gz
spark-d2e44d7db82ff3c3326af7bf7ea69c803803698e.tar.bz2
spark-d2e44d7db82ff3c3326af7bf7ea69c803803698e.zip
[SPARK-16192][SQL] Add type checks in CollectSet
## What changes were proposed in this pull request?

`CollectSet` cannot have map-typed data because `MapTypeData` does not implement `equals`. So, this PR adds type checks in `CheckAnalysis`.

## How was this patch tested?

Added tests to check for failures when map-typed data is found in `CollectSet`.

Author: Takeshi YAMAMURO <linguin.m.s@gmail.com>

Closes #13892 from maropu/SPARK-16192.
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala          |  4
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala  |  9
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala                     | 10
3 files changed, 21 insertions(+), 2 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 899227674f..ac9693e079 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -73,9 +73,9 @@ trait CheckAnalysis extends PredicateHelper {
s"invalid cast from ${c.child.dataType.simpleString} to ${c.dataType.simpleString}")
case g: Grouping =>
- failAnalysis(s"grouping() can only be used with GroupingSets/Cube/Rollup")
+ failAnalysis("grouping() can only be used with GroupingSets/Cube/Rollup")
case g: GroupingID =>
- failAnalysis(s"grouping_id() can only be used with GroupingSets/Cube/Rollup")
+ failAnalysis("grouping_id() can only be used with GroupingSets/Cube/Rollup")
case w @ WindowExpression(AggregateExpression(_, _, true, _), _) =>
failAnalysis(s"Distinct window functions are not supported: $w")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
index 1f4ff9c4b1..ac2cefaddc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate
import scala.collection.generic.Growable
import scala.collection.mutable
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.catalyst.InternalRow
@@ -107,6 +108,14 @@ case class CollectSet(
def this(child: Expression) = this(child, 0, 0)
+ override def checkInputDataTypes(): TypeCheckResult = {
+ if (!child.dataType.existsRecursively(_.isInstanceOf[MapType])) {
+ TypeCheckResult.TypeCheckSuccess
+ } else {
+ TypeCheckResult.TypeCheckFailure("collect_set() cannot have map type data")
+ }
+ }
+
override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
copy(mutableAggBufferOffset = newMutableAggBufferOffset)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 69a990789b..92aa7b9543 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -457,6 +457,16 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
)
}
+ test("collect_set functions cannot have maps") {
+ val df = Seq((1, 3, 0), (2, 3, 0), (3, 4, 1))
+ .toDF("a", "x", "y")
+ .select($"a", map($"x", $"y").as("b"))
+ val error = intercept[AnalysisException] {
+ df.select(collect_set($"a"), collect_set($"b"))
+ }
+ assert(error.message.contains("collect_set() cannot have map type data"))
+ }
+
test("SPARK-14664: Decimal sum/avg over window should work.") {
checkAnswer(
spark.sql("select sum(a) over () from values 1.0, 2.0, 3.0 T(a)"),