Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala | 11 +++++++++++
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala       |  4 ++--
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala                          |  9 +++++++++
3 files changed, 22 insertions(+), 2 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 831fb4fe95..96e2aee4de 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -69,6 +69,7 @@ trait HiveTypeCoercion {
   val typeCoercionRules =
     PropagateTypes ::
     ConvertNaNs ::
+    InConversion ::
     WidenTypes ::
     PromoteStrings ::
     DecimalPrecision ::
@@ -287,6 +288,16 @@ trait HiveTypeCoercion {
     }
   }
 
+  /**
+   * Convert all expressions in an in() list to the type of the left operand
+   */
+  object InConversion extends Rule[LogicalPlan] {
+    def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+      case i @ In(a, b) if b.exists(_.dataType != a.dataType) =>
+        i.makeCopy(Array(a, b.map(Cast(_, a.dataType))))
+    }
+  }
+
   // scalastyle:off
   /**
    * Calculates and propagates precision for fixed-precision decimals. Hive has a number of
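The effect of the new InConversion rule can be illustrated with a small standalone sketch. The snippet below is not Catalyst code; Expr, Attribute, Literal, Cast, In and the coerce helper are simplified stand-ins introduced here to mirror the rule's single pattern-match branch: whenever any element of an IN list has a type different from the left operand, every element is wrapped in a Cast to that operand's type.

// Simplified, self-contained sketch of what InConversion does. These case
// classes are illustrative stand-ins, not the real Catalyst expression tree.
object InConversionSketch {
  sealed trait DataType
  case object StringType extends DataType
  case object IntegerType extends DataType
  case object BooleanType extends DataType

  sealed trait Expr { def dataType: DataType }
  case class Attribute(name: String, dataType: DataType) extends Expr
  case class Literal(value: Any, dataType: DataType) extends Expr
  case class Cast(child: Expr, dataType: DataType) extends Expr
  case class In(value: Expr, list: Seq[Expr]) extends Expr {
    val dataType: DataType = BooleanType  // IN is a predicate
  }

  // Mirror of the rule's branch: cast every IN-list element whose type
  // differs from the left operand's type to that operand's type.
  def coerce(e: Expr): Expr = e match {
    case In(a, b) if b.exists(_.dataType != a.dataType) =>
      In(a, b.map(Cast(_, a.dataType)))
    case other => other
  }

  def main(args: Array[String]): Unit = {
    val before = In(Attribute("a", StringType),
      Seq(Literal(1, IntegerType), Literal(2, IntegerType)))
    // Prints the IN list with both literals wrapped in Cast(_, StringType).
    println(coerce(before))
  }
}

This is exactly the shape the rule produces for the query in the test below, d.a in (1,2) with a string column a: the integer literals are cast to strings before the comparison.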
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 709f7d672d..e4a60f53d6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -310,8 +310,8 @@ object OptimizeIn extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     case q: LogicalPlan => q transformExpressionsDown {
       case In(v, list) if !list.exists(!_.isInstanceOf[Literal]) =>
-          val hSet = list.map(e => e.eval(null))
-          InSet(v, HashSet() ++ hSet)
+        val hSet = list.map(e => e.eval(null))
+        InSet(v, HashSet() ++ hSet)
     }
   }
 }
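The whitespace-only change above sits in OptimizeIn, whose job is to replace an IN over an all-literal list with an InSet backed by a precomputed HashSet, so each row is tested with a hash lookup rather than compared against every list element. A rough standalone sketch of that idea, with plain Scala values standing in for the evaluated literals (the names here are illustrative, not Catalyst's):

import scala.collection.immutable.HashSet

object OptimizeInSketch {
  def main(args: Array[String]): Unit = {
    // Stand-ins for literals already evaluated via e.eval(null), as the rule
    // does once it knows the list contains nothing but Literal expressions.
    val evaluatedLiterals: Seq[Any] = Seq("1", "2")

    // InSet keeps a set like this; membership per row is a hash lookup
    // instead of a comparison against each element of the original list.
    val hSet: Set[Any] = HashSet[Any]() ++ evaluatedLiterals

    for (v <- Seq("1", "2", "3"))
      println(s"$v in (1,2) -> ${hSet.contains(v)}")
  }
}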
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 0ab8558c1d..208cec6a32 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -120,6 +120,15 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
       Row(1, 1) :: Nil)
   }
 
+  test("SPARK-6201 IN type conversion") {
+    jsonRDD(sparkContext.parallelize(Seq("{\"a\": \"1\"}}", "{\"a\": \"2\"}}", "{\"a\": \"3\"}}")))
+      .registerTempTable("d")
+
+    checkAnswer(
+      sql("select * from d where d.a in (1,2)"),
+      Seq(Row("1"), Row("2")))
+  }
+
   test("SPARK-3176 Added Parser of SQL ABS()") {
     checkAnswer(
       sql("SELECT ABS(-1.3)"),