diff options
Diffstat (limited to 'sql/catalyst')
-rw-r--r-- | sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala | 11 | ||||
-rw-r--r-- | sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala | 4 |
2 files changed, 13 insertions, 2 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 831fb4fe95..96e2aee4de 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -69,6 +69,7 @@ trait HiveTypeCoercion { val typeCoercionRules = PropagateTypes :: ConvertNaNs :: + InConversion :: WidenTypes :: PromoteStrings :: DecimalPrecision :: @@ -287,6 +288,16 @@ trait HiveTypeCoercion { } } + /** + * Convert all expressions in in() list to the left operator type + */ + object InConversion extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + case i @ In(a, b) if b.exists(_.dataType != a.dataType) => + i.makeCopy(Array(a, b.map(Cast(_, a.dataType)))) + } + } + // scalastyle:off /** * Calculates and propagates precision for fixed-precision decimals. Hive has a number of diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 709f7d672d..e4a60f53d6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -310,8 +310,8 @@ object OptimizeIn extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsDown { case In(v, list) if !list.exists(!_.isInstanceOf[Literal]) => - val hSet = list.map(e => e.eval(null)) - InSet(v, HashSet() ++ hSet) + val hSet = list.map(e => e.eval(null)) + InSet(v, HashSet() ++ hSet) } } } |