aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala1
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala8
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala29
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala6
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala8
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala6
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala10
7 files changed, 44 insertions, 24 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index a069b4710f..583338da57 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -26,7 +26,6 @@ import org.apache.spark.sql.types._
* Throws user facing errors when passed invalid queries that fail to analyze.
*/
trait CheckAnalysis {
- self: Analyzer =>
/**
* Override to provide additional checks for correct analysis.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index a9d396d1fa..2ab5cb666f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -45,7 +45,7 @@ object HiveTypeCoercion {
IfCoercion ::
Division ::
PropagateTypes ::
- AddCastForAutoCastInputTypes ::
+ ImplicitTypeCasts ::
Nil
// See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types.
@@ -705,13 +705,13 @@ object HiveTypeCoercion {
* Casts types according to the expected input types for Expressions that have the trait
* [[AutoCastInputTypes]].
*/
- object AddCastForAutoCastInputTypes extends Rule[LogicalPlan] {
+ object ImplicitTypeCasts extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
// Skip nodes whose children have not been resolved yet.
case e if !e.childrenResolved => e
- case e: AutoCastInputTypes if e.children.map(_.dataType) != e.expectedChildTypes =>
- val newC = (e.children, e.children.map(_.dataType), e.expectedChildTypes).zipped.map {
+ case e: AutoCastInputTypes if e.children.map(_.dataType) != e.inputTypes =>
+ val newC = (e.children, e.children.map(_.dataType), e.inputTypes).zipped.map {
case (child, actual, expected) =>
if (actual == expected) child else Cast(child, expected)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index b5063f32fa..e18a311894 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -266,16 +266,37 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio
}
/**
+ * A trait that is mixed in to define the expected input types of an expression.
+ */
+trait ExpectsInputTypes { self: Expression =>
+
+ /**
+ * Expected input types from child expressions. The i-th position in the returned seq indicates
+ * the type requirement for the i-th child.
+ *
+ * The possible values at each position are:
+ * 1. a specific data type, e.g. LongType, StringType.
+ * 2. a non-leaf data type, e.g. NumericType, IntegralType, FractionalType.
+ * 3. a list of specific data types, e.g. Seq(StringType, BinaryType).
+ */
+ def inputTypes: Seq[Any]
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ // We will do the type checking in `HiveTypeCoercion`, so always returning success here.
+ TypeCheckResult.TypeCheckSuccess
+ }
+}
+
+/**
* Expressions that require a specific `DataType` as input should implement this trait
* so that the proper type conversions can be performed in the analyzer.
*/
-trait AutoCastInputTypes {
- self: Expression =>
+trait AutoCastInputTypes { self: Expression =>
- def expectedChildTypes: Seq[DataType]
+ def inputTypes: Seq[DataType]
override def checkInputDataTypes(): TypeCheckResult = {
- // We will always do type casting for `ExpectsInputTypes` in `HiveTypeCoercion`,
+ // We will always do type casting for `AutoCastInputTypes` in `HiveTypeCoercion`,
// so type mismatch error won't be reported here, but for underlying `Cast`s.
TypeCheckResult.TypeCheckSuccess
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
index da63f2fa97..b51318dd50 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
@@ -59,7 +59,7 @@ abstract class UnaryMathExpression(f: Double => Double, name: String)
extends UnaryExpression with Serializable with AutoCastInputTypes {
self: Product =>
- override def expectedChildTypes: Seq[DataType] = Seq(DoubleType)
+ override def inputTypes: Seq[DataType] = Seq(DoubleType)
override def dataType: DataType = DoubleType
override def nullable: Boolean = true
override def toString: String = s"$name($child)"
@@ -98,7 +98,7 @@ abstract class UnaryMathExpression(f: Double => Double, name: String)
abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String)
extends BinaryExpression with Serializable with AutoCastInputTypes { self: Product =>
- override def expectedChildTypes: Seq[DataType] = Seq(DoubleType, DoubleType)
+ override def inputTypes: Seq[DataType] = Seq(DoubleType, DoubleType)
override def toString: String = s"$name($left, $right)"
@@ -210,7 +210,7 @@ case class ToRadians(child: Expression) extends UnaryMathExpression(math.toRadia
case class Bin(child: Expression)
extends UnaryExpression with Serializable with AutoCastInputTypes {
- override def expectedChildTypes: Seq[DataType] = Seq(LongType)
+ override def inputTypes: Seq[DataType] = Seq(LongType)
override def dataType: DataType = StringType
override def eval(input: InternalRow): Any = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
index a7bcbe46c3..407023e472 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
@@ -36,7 +36,7 @@ case class Md5(child: Expression)
override def dataType: DataType = StringType
- override def expectedChildTypes: Seq[DataType] = Seq(BinaryType)
+ override def inputTypes: Seq[DataType] = Seq(BinaryType)
override def eval(input: InternalRow): Any = {
val value = child.eval(input)
@@ -68,7 +68,7 @@ case class Sha2(left: Expression, right: Expression)
override def toString: String = s"SHA2($left, $right)"
- override def expectedChildTypes: Seq[DataType] = Seq(BinaryType, IntegerType)
+ override def inputTypes: Seq[DataType] = Seq(BinaryType, IntegerType)
override def eval(input: InternalRow): Any = {
val evalE1 = left.eval(input)
@@ -151,7 +151,7 @@ case class Sha1(child: Expression) extends UnaryExpression with AutoCastInputTyp
override def dataType: DataType = StringType
- override def expectedChildTypes: Seq[DataType] = Seq(BinaryType)
+ override def inputTypes: Seq[DataType] = Seq(BinaryType)
override def eval(input: InternalRow): Any = {
val value = child.eval(input)
@@ -179,7 +179,7 @@ case class Crc32(child: Expression)
override def dataType: DataType = LongType
- override def expectedChildTypes: Seq[DataType] = Seq(BinaryType)
+ override def inputTypes: Seq[DataType] = Seq(BinaryType)
override def eval(input: InternalRow): Any = {
val value = child.eval(input)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 98cd5aa814..a777f77add 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -72,7 +72,7 @@ trait PredicateHelper {
case class Not(child: Expression) extends UnaryExpression with Predicate with AutoCastInputTypes {
override def toString: String = s"NOT $child"
- override def expectedChildTypes: Seq[DataType] = Seq(BooleanType)
+ override def inputTypes: Seq[DataType] = Seq(BooleanType)
override def eval(input: InternalRow): Any = {
child.eval(input) match {
@@ -122,7 +122,7 @@ case class InSet(value: Expression, hset: Set[Any])
case class And(left: Expression, right: Expression)
extends BinaryExpression with Predicate with AutoCastInputTypes {
- override def expectedChildTypes: Seq[DataType] = Seq(BooleanType, BooleanType)
+ override def inputTypes: Seq[DataType] = Seq(BooleanType, BooleanType)
override def symbol: String = "&&"
@@ -171,7 +171,7 @@ case class And(left: Expression, right: Expression)
case class Or(left: Expression, right: Expression)
extends BinaryExpression with Predicate with AutoCastInputTypes {
- override def expectedChildTypes: Seq[DataType] = Seq(BooleanType, BooleanType)
+ override def inputTypes: Seq[DataType] = Seq(BooleanType, BooleanType)
override def symbol: String = "||"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index ce184e4f32..4cbfc4e084 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -32,7 +32,7 @@ trait StringRegexExpression extends AutoCastInputTypes {
override def nullable: Boolean = left.nullable || right.nullable
override def dataType: DataType = BooleanType
- override def expectedChildTypes: Seq[DataType] = Seq(StringType, StringType)
+ override def inputTypes: Seq[DataType] = Seq(StringType, StringType)
// try cache the pattern for Literal
private lazy val cache: Pattern = right match {
@@ -117,7 +117,7 @@ trait CaseConversionExpression extends AutoCastInputTypes {
def convert(v: UTF8String): UTF8String
override def dataType: DataType = StringType
- override def expectedChildTypes: Seq[DataType] = Seq(StringType)
+ override def inputTypes: Seq[DataType] = Seq(StringType)
override def eval(input: InternalRow): Any = {
val evaluated = child.eval(input)
@@ -165,7 +165,7 @@ trait StringComparison extends AutoCastInputTypes {
override def nullable: Boolean = left.nullable || right.nullable
- override def expectedChildTypes: Seq[DataType] = Seq(StringType, StringType)
+ override def inputTypes: Seq[DataType] = Seq(StringType, StringType)
override def eval(input: InternalRow): Any = {
val leftEval = left.eval(input)
@@ -238,7 +238,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression)
if (str.dataType == BinaryType) str.dataType else StringType
}
- override def expectedChildTypes: Seq[DataType] = Seq(StringType, IntegerType, IntegerType)
+ override def inputTypes: Seq[DataType] = Seq(StringType, IntegerType, IntegerType)
override def children: Seq[Expression] = str :: pos :: len :: Nil
@@ -297,7 +297,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression)
*/
case class StringLength(child: Expression) extends UnaryExpression with AutoCastInputTypes {
override def dataType: DataType = IntegerType
- override def expectedChildTypes: Seq[DataType] = Seq(StringType)
+ override def inputTypes: Seq[DataType] = Seq(StringType)
override def eval(input: InternalRow): Any = {
val string = child.eval(input)