aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala20
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala11
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala9
3 files changed, 39 insertions, 1 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index a373714832..0ebc3d180a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -87,6 +87,18 @@ trait CheckAnalysis {
s"join condition '${condition.prettyString}' " +
s"of type ${condition.dataType.simpleString} is not a boolean.")
+ case j @ Join(_, _, _, Some(condition)) =>
+ def checkValidJoinConditionExprs(expr: Expression): Unit = expr match {
+ case p: Predicate =>
+ p.asInstanceOf[Expression].children.foreach(checkValidJoinConditionExprs)
+ case e if e.dataType.isInstanceOf[BinaryType] =>
+ failAnalysis(s"expression ${e.prettyString} in join condition " +
+ s"'${condition.prettyString}' can't be binary type.")
+ case _ => // OK
+ }
+
+ checkValidJoinConditionExprs(condition)
+
case Aggregate(groupingExprs, aggregateExprs, child) =>
def checkValidAggregateExpression(expr: Expression): Unit = expr match {
case _: AggregateExpression => // OK
@@ -100,7 +112,15 @@ trait CheckAnalysis {
case e => e.children.foreach(checkValidAggregateExpression)
}
+ def checkValidGroupingExprs(expr: Expression): Unit = expr.dataType match {
+ case BinaryType =>
+ failAnalysis(s"grouping expression '${expr.prettyString}' in aggregate can " +
+ s"not be binary type.")
+ case _ => // OK
+ }
+
aggregateExprs.foreach(checkValidAggregateExpression)
+ aggregateExprs.foreach(checkValidGroupingExprs)
case Sort(orders, _, _) =>
orders.foreach { order =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index b26d3ab253..228ece8065 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.DecimalType
+import org.apache.spark.sql.types.{BinaryType, DecimalType}
class DataFrameAggregateSuite extends QueryTest {
@@ -191,4 +191,13 @@ class DataFrameAggregateSuite extends QueryTest {
Row(null))
}
+ test("aggregation can't work on binary type") {
+ val df = Seq(1, 1, 2, 2).map(i => Tuple1(i.toString)).toDF("c").select($"c" cast BinaryType)
+ intercept[AnalysisException] {
+ df.groupBy("c").agg(count("*"))
+ }
+ intercept[AnalysisException] {
+ df.distinct
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 666f26bf62..27c08f6464 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -22,6 +22,7 @@ import org.scalatest.BeforeAndAfterEach
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.execution.joins._
+import org.apache.spark.sql.types.BinaryType
class JoinSuite extends QueryTest with BeforeAndAfterEach {
@@ -489,4 +490,12 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
Row(3, 2) :: Nil)
}
+
+ test("Join can't work on binary type") {
+ val left = Seq(1, 1, 2, 2).map(i => Tuple1(i.toString)).toDF("c").select($"c" cast BinaryType)
+ val right = Seq(1, 1, 2, 2).map(i => Tuple1(i.toString)).toDF("d").select($"d" cast BinaryType)
+ intercept[AnalysisException] {
+ left.join(right, ($"left.N" === $"right.N"), "full")
+ }
+ }
}