author    Reynold Xin <rxin@databricks.com>  2015-07-01 10:30:54 -0700
committer Reynold Xin <rxin@databricks.com>  2015-07-01 10:30:54 -0700
commit    4137f769b84300648ad933b0b3054d69a7316745
tree      b7db2ea26789de4b4756525c4573c25c12e167d1
parent    69c5dee2f01b1ae35bd813d31d46429a32cb475d
[SPARK-8752][SQL] Add ExpectsInputTypes trait for defining expected input types.
This patch doesn't actually introduce any code that uses the new ExpectsInputTypes trait. It just adds the trait so others can use it. It also renames the old expectedChildTypes function to just inputTypes. We should also add implicit type casting in the future.

Author: Reynold Xin <rxin@databricks.com>

Closes #7151 from rxin/expects-input-types and squashes the following commits:

16cf07b [Reynold Xin] [SPARK-8752][SQL] Add ExpectsInputTypes trait for defining expected input types.
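For orientation, here is a minimal sketch of how a new expression might mix in the trait this patch adds. Elem and all of its members other than inputTypes are invented for this example, and the sketch is written against the 1.4-era Catalyst API rather than taken from the patch, so some abstract members may differ at other revisions:

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.types._

    // Hypothetical expression: the first child may be a string or binary value,
    // the second child must be an integer. Only the inputTypes contract below
    // comes from this patch; everything else is illustrative.
    case class Elem(left: Expression, right: Expression)
      extends BinaryExpression with ExpectsInputTypes {

      // One entry per child; a Seq entry lists acceptable alternatives.
      override def inputTypes: Seq[Any] = Seq(Seq(StringType, BinaryType), IntegerType)

      override def symbol: String = "elem"
      override def dataType: DataType = StringType
      override def eval(input: InternalRow): Any = ??? // evaluation elided in this sketch
    }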
Diffstat (limited to 'sql')
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala       |  1
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala    |  8
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala       | 29
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala             |  6
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala             |  8
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala       |  6
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala | 10
 7 files changed, 44 insertions(+), 24 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index a069b4710f..583338da57 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -26,7 +26,6 @@ import org.apache.spark.sql.types._
* Throws user facing errors when passed invalid queries that fail to analyze.
*/
trait CheckAnalysis {
- self: Analyzer =>
/**
* Override to provide additional checks for correct analysis.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index a9d396d1fa..2ab5cb666f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -45,7 +45,7 @@ object HiveTypeCoercion {
IfCoercion ::
Division ::
PropagateTypes ::
- AddCastForAutoCastInputTypes ::
+ ImplicitTypeCasts ::
Nil
// See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types.
@@ -705,13 +705,13 @@ object HiveTypeCoercion {
* Casts types according to the expected input types for Expressions that have the trait
* [[AutoCastInputTypes]].
*/
- object AddCastForAutoCastInputTypes extends Rule[LogicalPlan] {
+ object ImplicitTypeCasts extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
// Skip nodes whose children have not been resolved yet.
case e if !e.childrenResolved => e
- case e: AutoCastInputTypes if e.children.map(_.dataType) != e.expectedChildTypes =>
- val newC = (e.children, e.children.map(_.dataType), e.expectedChildTypes).zipped.map {
+ case e: AutoCastInputTypes if e.children.map(_.dataType) != e.inputTypes =>
+ val newC = (e.children, e.children.map(_.dataType), e.inputTypes).zipped.map {
case (child, actual, expected) =>
if (actual == expected) child else Cast(child, expected)
}
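As a standalone restatement, the heart of the rule is this zipped traversal, which leaves a child alone when its type already matches and otherwise wraps it in a Cast. A minimal sketch, assuming the 1.4-era Catalyst classes are on the classpath:

    import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, Literal}
    import org.apache.spark.sql.types._

    // Same shape as the rule body above, pulled out as a plain function.
    def castChildren(children: Seq[Expression], expected: Seq[DataType]): Seq[Expression] =
      (children, children.map(_.dataType), expected).zipped.map {
        case (child, actual, want) => if (actual == want) child else Cast(child, want)
      }

    // Example: an IntegerType literal expected to be DoubleType gets wrapped:
    // castChildren(Seq(Literal(1)), Seq(DoubleType)) == Seq(Cast(Literal(1), DoubleType))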
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index b5063f32fa..e18a311894 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -266,16 +266,37 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio
}
/**
+ * A trait that is mixed in to define the expected input types of an expression.
+ */
+trait ExpectsInputTypes { self: Expression =>
+
+ /**
+ * Expected input types from child expressions. The i-th position in the returned seq indicates
+ * the type requirement for the i-th child.
+ *
+ * The possible values at each position are:
+ * 1. a specific data type, e.g. LongType, StringType.
+ * 2. a non-leaf data type, e.g. NumericType, IntegralType, FractionalType.
+ * 3. a list of specific data types, e.g. Seq(StringType, BinaryType).
+ */
+ def inputTypes: Seq[Any]
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ // We will do the type checking in `HiveTypeCoercion`, so we always return success here.
+ TypeCheckResult.TypeCheckSuccess
+ }
+}
+
+/**
* Expressions that require a specific `DataType` as input should implement this trait
* so that the proper type conversions can be performed in the analyzer.
*/
-trait AutoCastInputTypes {
- self: Expression =>
+trait AutoCastInputTypes { self: Expression =>
- def expectedChildTypes: Seq[DataType]
+ def inputTypes: Seq[DataType]
override def checkInputDataTypes(): TypeCheckResult = {
- // We will always do type casting for `ExpectsInputTypes` in `HiveTypeCoercion`,
+ // We will always do type casting for `AutoCastInputTypes` in `HiveTypeCoercion`,
// so type mismatch errors won't be reported here, but for the underlying `Cast`s.
TypeCheckResult.TypeCheckSuccess
}
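To make the new doc comment concrete, an inputTypes value exercising all three shapes could look like the following. This is purely illustrative: no expression in this patch uses a mixed specification, and the patch does not pin down how a non-leaf requirement is encoded, so classOf is used here only as a placeholder:

    import org.apache.spark.sql.types._

    val spec: Seq[Any] = Seq(
      LongType,                      // 1. a specific data type: the child must be LongType
      classOf[NumericType],          // 2. a non-leaf data type: any numeric type (placeholder encoding)
      Seq(StringType, BinaryType)    // 3. a list of specific data types: either is acceptable
    )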
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
index da63f2fa97..b51318dd50 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
@@ -59,7 +59,7 @@ abstract class UnaryMathExpression(f: Double => Double, name: String)
extends UnaryExpression with Serializable with AutoCastInputTypes {
self: Product =>
- override def expectedChildTypes: Seq[DataType] = Seq(DoubleType)
+ override def inputTypes: Seq[DataType] = Seq(DoubleType)
override def dataType: DataType = DoubleType
override def nullable: Boolean = true
override def toString: String = s"$name($child)"
@@ -98,7 +98,7 @@ abstract class UnaryMathExpression(f: Double => Double, name: String)
abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String)
extends BinaryExpression with Serializable with AutoCastInputTypes { self: Product =>
- override def expectedChildTypes: Seq[DataType] = Seq(DoubleType, DoubleType)
+ override def inputTypes: Seq[DataType] = Seq(DoubleType, DoubleType)
override def toString: String = s"$name($left, $right)"
@@ -210,7 +210,7 @@ case class ToRadians(child: Expression) extends UnaryMathExpression(math.toRadia
case class Bin(child: Expression)
extends UnaryExpression with Serializable with AutoCastInputTypes {
- override def expectedChildTypes: Seq[DataType] = Seq(LongType)
+ override def inputTypes: Seq[DataType] = Seq(LongType)
override def dataType: DataType = StringType
override def eval(input: InternalRow): Any = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
index a7bcbe46c3..407023e472 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
@@ -36,7 +36,7 @@ case class Md5(child: Expression)
override def dataType: DataType = StringType
- override def expectedChildTypes: Seq[DataType] = Seq(BinaryType)
+ override def inputTypes: Seq[DataType] = Seq(BinaryType)
override def eval(input: InternalRow): Any = {
val value = child.eval(input)
@@ -68,7 +68,7 @@ case class Sha2(left: Expression, right: Expression)
override def toString: String = s"SHA2($left, $right)"
- override def expectedChildTypes: Seq[DataType] = Seq(BinaryType, IntegerType)
+ override def inputTypes: Seq[DataType] = Seq(BinaryType, IntegerType)
override def eval(input: InternalRow): Any = {
val evalE1 = left.eval(input)
@@ -151,7 +151,7 @@ case class Sha1(child: Expression) extends UnaryExpression with AutoCastInputTyp
override def dataType: DataType = StringType
- override def expectedChildTypes: Seq[DataType] = Seq(BinaryType)
+ override def inputTypes: Seq[DataType] = Seq(BinaryType)
override def eval(input: InternalRow): Any = {
val value = child.eval(input)
@@ -179,7 +179,7 @@ case class Crc32(child: Expression)
override def dataType: DataType = LongType
- override def expectedChildTypes: Seq[DataType] = Seq(BinaryType)
+ override def inputTypes: Seq[DataType] = Seq(BinaryType)
override def eval(input: InternalRow): Any = {
val value = child.eval(input)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 98cd5aa814..a777f77add 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -72,7 +72,7 @@ trait PredicateHelper {
case class Not(child: Expression) extends UnaryExpression with Predicate with AutoCastInputTypes {
override def toString: String = s"NOT $child"
- override def expectedChildTypes: Seq[DataType] = Seq(BooleanType)
+ override def inputTypes: Seq[DataType] = Seq(BooleanType)
override def eval(input: InternalRow): Any = {
child.eval(input) match {
@@ -122,7 +122,7 @@ case class InSet(value: Expression, hset: Set[Any])
case class And(left: Expression, right: Expression)
extends BinaryExpression with Predicate with AutoCastInputTypes {
- override def expectedChildTypes: Seq[DataType] = Seq(BooleanType, BooleanType)
+ override def inputTypes: Seq[DataType] = Seq(BooleanType, BooleanType)
override def symbol: String = "&&"
@@ -171,7 +171,7 @@ case class And(left: Expression, right: Expression)
case class Or(left: Expression, right: Expression)
extends BinaryExpression with Predicate with AutoCastInputTypes {
- override def expectedChildTypes: Seq[DataType] = Seq(BooleanType, BooleanType)
+ override def inputTypes: Seq[DataType] = Seq(BooleanType, BooleanType)
override def symbol: String = "||"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index ce184e4f32..4cbfc4e084 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -32,7 +32,7 @@ trait StringRegexExpression extends AutoCastInputTypes {
override def nullable: Boolean = left.nullable || right.nullable
override def dataType: DataType = BooleanType
- override def expectedChildTypes: Seq[DataType] = Seq(StringType, StringType)
+ override def inputTypes: Seq[DataType] = Seq(StringType, StringType)
// try to cache the pattern for Literal
private lazy val cache: Pattern = right match {
@@ -117,7 +117,7 @@ trait CaseConversionExpression extends AutoCastInputTypes {
def convert(v: UTF8String): UTF8String
override def dataType: DataType = StringType
- override def expectedChildTypes: Seq[DataType] = Seq(StringType)
+ override def inputTypes: Seq[DataType] = Seq(StringType)
override def eval(input: InternalRow): Any = {
val evaluated = child.eval(input)
@@ -165,7 +165,7 @@ trait StringComparison extends AutoCastInputTypes {
override def nullable: Boolean = left.nullable || right.nullable
- override def expectedChildTypes: Seq[DataType] = Seq(StringType, StringType)
+ override def inputTypes: Seq[DataType] = Seq(StringType, StringType)
override def eval(input: InternalRow): Any = {
val leftEval = left.eval(input)
@@ -238,7 +238,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression)
if (str.dataType == BinaryType) str.dataType else StringType
}
- override def expectedChildTypes: Seq[DataType] = Seq(StringType, IntegerType, IntegerType)
+ override def inputTypes: Seq[DataType] = Seq(StringType, IntegerType, IntegerType)
override def children: Seq[Expression] = str :: pos :: len :: Nil
@@ -297,7 +297,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression)
*/
case class StringLength(child: Expression) extends UnaryExpression with AutoCastInputTypes {
override def dataType: DataType = IntegerType
- override def expectedChildTypes: Seq[DataType] = Seq(StringType)
+ override def inputTypes: Seq[DataType] = Seq(StringType)
override def eval(input: InternalRow): Any = {
val string = child.eval(input)
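Putting the pieces together, the net effect of the rename is that the coercion rule now reads inputTypes when deciding where to insert casts. For example, a sketch (not a test from the patch) of what the analyzer does with a mistyped StringLength child:

    import org.apache.spark.sql.catalyst.expressions.{Cast, Literal, StringLength}
    import org.apache.spark.sql.types.StringType

    // StringLength declares inputTypes = Seq(StringType), so an IntegerType child
    // does not match and ImplicitTypeCasts rewrites the expression:
    val before = StringLength(Literal(42))                    // child is IntegerType
    val after  = StringLength(Cast(Literal(42), StringType))  // what the rule produces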