path: root/sql
author     Reynold Xin <rxin@databricks.com>  2015-05-04 18:03:07 -0700
committer  Reynold Xin <rxin@databricks.com>  2015-05-04 18:03:07 -0700
commit     678c4da0fa1bbfb6b5a0d3aced7aefa1bbbc193c (patch)
tree       868d4afbdbfcf366483db247a22a572f4816d6ea /sql
parent     80554111703c08e2bedbe303e04ecd162ec119e1 (diff)
[SPARK-7266] Add ExpectsInputTypes to expressions when possible.
This should give us better analysis-time error messages (rather than runtime ones) and automatic type casting.

Author: Reynold Xin <rxin@databricks.com>

Closes #5796 from rxin/expected-input-types and squashes the following commits:

c900760 [Reynold Xin] [SPARK-7266] Add ExpectsInputTypes to expressions when possible.
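To make the intent concrete, here is a minimal, self-contained Scala sketch of the pattern this patch applies. It is not Spark source: the toy Literal node, the coerce helper, and the simplified type hierarchy are illustrative stand-ins; only ExpectsInputTypes, expectedChildTypes, Cast, and Contains mirror names that appear in the diff below. The idea is that an expression declares the types it expects for its children, so an analysis-time rule can insert casts (or report an error) before the query ever runs.

// Simplified stand-ins for Catalyst's type and expression hierarchy.
sealed trait DataType
case object BooleanType extends DataType
case object StringType  extends DataType
case object IntegerType extends DataType

trait Expression { def dataType: DataType }
case class Literal(value: Any, dataType: DataType) extends Expression
case class Cast(child: Expression, dataType: DataType) extends Expression

// The contract this patch mixes into predicates and string expressions.
trait ExpectsInputTypes { def expectedChildTypes: Seq[DataType] }

case class Contains(left: Expression, right: Expression)
    extends Expression with ExpectsInputTypes {
  def dataType: DataType = BooleanType
  def expectedChildTypes: Seq[DataType] = Seq(StringType, StringType)
}

object ExpectsInputTypesSketch extends App {
  // Toy analysis-time rule (hypothetical; not the actual HiveTypeCoercion code):
  // wrap any mistyped child in a Cast so the mismatch is fixed before evaluation.
  def coerce(e: Contains): Contains = {
    val Seq(lt, rt) = e.expectedChildTypes
    def fix(c: Expression, t: DataType): Expression =
      if (c.dataType == t) c else Cast(c, t)
    e.copy(left = fix(e.left, lt), right = fix(e.right, rt))
  }

  // The integer argument is cast to StringType during analysis, instead of
  // surfacing as a type error at runtime.
  val coerced = coerce(Contains(Literal("abc", StringType), Literal(1, IntegerType)))
  println(coerced)
  // Contains(Literal(abc,StringType), Cast(Literal(1,IntegerType),StringType))
}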
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala        58
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala            3
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala            4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala           22
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala     40
5 files changed, 71 insertions, 56 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 73c9a1c7af..831fb4fe95 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -239,37 +239,43 @@ trait HiveTypeCoercion {
a.makeCopy(Array(a.left, Cast(a.right, DoubleType)))
// we should cast all timestamp/date/string compare into string compare
- case p: BinaryPredicate if p.left.dataType == StringType
- && p.right.dataType == DateType =>
+ case p: BinaryComparison if p.left.dataType == StringType &&
+ p.right.dataType == DateType =>
p.makeCopy(Array(p.left, Cast(p.right, StringType)))
- case p: BinaryPredicate if p.left.dataType == DateType
- && p.right.dataType == StringType =>
+ case p: BinaryComparison if p.left.dataType == DateType &&
+ p.right.dataType == StringType =>
p.makeCopy(Array(Cast(p.left, StringType), p.right))
- case p: BinaryPredicate if p.left.dataType == StringType
- && p.right.dataType == TimestampType =>
+ case p: BinaryComparison if p.left.dataType == StringType &&
+ p.right.dataType == TimestampType =>
p.makeCopy(Array(p.left, Cast(p.right, StringType)))
- case p: BinaryPredicate if p.left.dataType == TimestampType
- && p.right.dataType == StringType =>
+ case p: BinaryComparison if p.left.dataType == TimestampType &&
+ p.right.dataType == StringType =>
p.makeCopy(Array(Cast(p.left, StringType), p.right))
- case p: BinaryPredicate if p.left.dataType == TimestampType
- && p.right.dataType == DateType =>
+ case p: BinaryComparison if p.left.dataType == TimestampType &&
+ p.right.dataType == DateType =>
p.makeCopy(Array(Cast(p.left, StringType), Cast(p.right, StringType)))
- case p: BinaryPredicate if p.left.dataType == DateType
- && p.right.dataType == TimestampType =>
+ case p: BinaryComparison if p.left.dataType == DateType &&
+ p.right.dataType == TimestampType =>
p.makeCopy(Array(Cast(p.left, StringType), Cast(p.right, StringType)))
- case p: BinaryPredicate if p.left.dataType == StringType && p.right.dataType != StringType =>
+ case p: BinaryComparison if p.left.dataType == StringType &&
+ p.right.dataType != StringType =>
p.makeCopy(Array(Cast(p.left, DoubleType), p.right))
- case p: BinaryPredicate if p.left.dataType != StringType && p.right.dataType == StringType =>
+ case p: BinaryComparison if p.left.dataType != StringType &&
+ p.right.dataType == StringType =>
p.makeCopy(Array(p.left, Cast(p.right, DoubleType)))
- case i @ In(a, b) if a.dataType == DateType && b.forall(_.dataType == StringType) =>
+ case i @ In(a, b) if a.dataType == DateType &&
+ b.forall(_.dataType == StringType) =>
i.makeCopy(Array(Cast(a, StringType), b))
- case i @ In(a, b) if a.dataType == TimestampType && b.forall(_.dataType == StringType) =>
+ case i @ In(a, b) if a.dataType == TimestampType &&
+ b.forall(_.dataType == StringType) =>
i.makeCopy(Array(Cast(a, StringType), b))
- case i @ In(a, b) if a.dataType == DateType && b.forall(_.dataType == TimestampType) =>
+ case i @ In(a, b) if a.dataType == DateType &&
+ b.forall(_.dataType == TimestampType) =>
i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType))))
- case i @ In(a, b) if a.dataType == TimestampType && b.forall(_.dataType == DateType) =>
+ case i @ In(a, b) if a.dataType == TimestampType &&
+ b.forall(_.dataType == DateType) =>
i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType))))
case Sum(e) if e.dataType == StringType =>
@@ -420,19 +426,19 @@ trait HiveTypeCoercion {
)
case LessThan(e1 @ DecimalType.Expression(p1, s1),
- e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
+ e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
LessThan(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited))
case LessThanOrEqual(e1 @ DecimalType.Expression(p1, s1),
- e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
+ e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
LessThanOrEqual(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited))
case GreaterThan(e1 @ DecimalType.Expression(p1, s1),
- e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
+ e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
GreaterThan(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited))
case GreaterThanOrEqual(e1 @ DecimalType.Expression(p1, s1),
- e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
+ e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
GreaterThanOrEqual(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited))
// Promote integers inside a binary expression with fixed-precision decimals to decimals,
@@ -481,8 +487,8 @@ trait HiveTypeCoercion {
// No need to change the EqualNullSafe operators, too
case e: EqualNullSafe => e
// Otherwise turn them to Byte types so that there exists and ordering.
- case p: BinaryComparison
- if p.left.dataType == BooleanType && p.right.dataType == BooleanType =>
+ case p: BinaryComparison if p.left.dataType == BooleanType &&
+ p.right.dataType == BooleanType =>
p.makeCopy(Array(Cast(p.left, ByteType), Cast(p.right, ByteType)))
}
}
@@ -564,10 +570,6 @@ trait HiveTypeCoercion {
case Sum(e @ TimestampType()) => Sum(Cast(e, DoubleType))
case Average(e @ TimestampType()) => Average(Cast(e, DoubleType))
- // Compatible with Hive
- case Substring(e, start, len) if e.dataType != StringType =>
- Substring(Cast(e, StringType), start, len)
-
// Coalesce should return the first non-null value, which could be any column
// from the list. So we need to make sure the return type is deterministic and
// compatible with every child column.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 1d71c1b4b0..4fd1bc4dd6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -18,7 +18,6 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
-import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.types._
@@ -86,6 +85,8 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express
override def foldable: Boolean = left.foldable && right.foldable
+ override def nullable: Boolean = left.nullable || right.nullable
+
override def toString: String = s"($left $symbol $right)"
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 140ccd8d37..c7a37ad966 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -74,14 +74,12 @@ abstract class BinaryArithmetic extends BinaryExpression {
type EvaluatedType = Any
- def nullable: Boolean = left.nullable || right.nullable
-
override lazy val resolved =
left.resolved && right.resolved &&
left.dataType == right.dataType &&
!DecimalType.isFixed(left.dataType)
- def dataType: DataType = {
+ override def dataType: DataType = {
if (!resolved) {
throw new UnresolvedException(this,
s"datatype. Can not resolve due to differing types ${left.dataType}, ${right.dataType}")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 9cb00cb273..26c38c56c0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -70,16 +70,14 @@ trait PredicateHelper {
expr.references.subsetOf(plan.outputSet)
}
-abstract class BinaryPredicate extends BinaryExpression with Predicate {
- self: Product =>
- override def nullable: Boolean = left.nullable || right.nullable
-}
-case class Not(child: Expression) extends UnaryExpression with Predicate {
+case class Not(child: Expression) extends UnaryExpression with Predicate with ExpectsInputTypes {
override def foldable: Boolean = child.foldable
override def nullable: Boolean = child.nullable
override def toString: String = s"NOT $child"
+ override def expectedChildTypes: Seq[DataType] = Seq(BooleanType)
+
override def eval(input: Row): Any = {
child.eval(input) match {
case null => null
@@ -120,7 +118,11 @@ case class InSet(value: Expression, hset: Set[Any])
}
}
-case class And(left: Expression, right: Expression) extends BinaryPredicate {
+case class And(left: Expression, right: Expression)
+ extends BinaryExpression with Predicate with ExpectsInputTypes {
+
+ override def expectedChildTypes: Seq[DataType] = Seq(BooleanType, BooleanType)
+
override def symbol: String = "&&"
override def eval(input: Row): Any = {
@@ -142,7 +144,11 @@ case class And(left: Expression, right: Expression) extends BinaryPredicate {
}
}
-case class Or(left: Expression, right: Expression) extends BinaryPredicate {
+case class Or(left: Expression, right: Expression)
+ extends BinaryExpression with Predicate with ExpectsInputTypes {
+
+ override def expectedChildTypes: Seq[DataType] = Seq(BooleanType, BooleanType)
+
override def symbol: String = "||"
override def eval(input: Row): Any = {
@@ -164,7 +170,7 @@ case class Or(left: Expression, right: Expression) extends BinaryPredicate {
}
}
-abstract class BinaryComparison extends BinaryPredicate {
+abstract class BinaryComparison extends BinaryExpression with Predicate {
self: Product =>
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index d597bf7ce7..d6f23df30f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -22,7 +22,7 @@ import java.util.regex.Pattern
import org.apache.spark.sql.catalyst.analysis.UnresolvedException
import org.apache.spark.sql.types._
-trait StringRegexExpression {
+trait StringRegexExpression extends ExpectsInputTypes {
self: BinaryExpression =>
type EvaluatedType = Any
@@ -32,6 +32,7 @@ trait StringRegexExpression {
override def nullable: Boolean = left.nullable || right.nullable
override def dataType: DataType = BooleanType
+ override def expectedChildTypes: Seq[DataType] = Seq(StringType, StringType)
// try cache the pattern for Literal
private lazy val cache: Pattern = right match {
@@ -57,11 +58,11 @@ trait StringRegexExpression {
if(r == null) {
null
} else {
- val regex = pattern(r.asInstanceOf[UTF8String].toString)
+ val regex = pattern(r.asInstanceOf[UTF8String].toString())
if(regex == null) {
null
} else {
- matches(regex, l.asInstanceOf[UTF8String].toString)
+ matches(regex, l.asInstanceOf[UTF8String].toString())
}
}
}
@@ -110,7 +111,7 @@ case class RLike(left: Expression, right: Expression)
override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0)
}
-trait CaseConversionExpression {
+trait CaseConversionExpression extends ExpectsInputTypes {
self: UnaryExpression =>
type EvaluatedType = Any
@@ -118,8 +119,9 @@ trait CaseConversionExpression {
def convert(v: UTF8String): UTF8String
override def foldable: Boolean = child.foldable
- def nullable: Boolean = child.nullable
- def dataType: DataType = StringType
+ override def nullable: Boolean = child.nullable
+ override def dataType: DataType = StringType
+ override def expectedChildTypes: Seq[DataType] = Seq(StringType)
override def eval(input: Row): Any = {
val evaluated = child.eval(input)
@@ -136,7 +138,7 @@ trait CaseConversionExpression {
*/
case class Upper(child: Expression) extends UnaryExpression with CaseConversionExpression {
- override def convert(v: UTF8String): UTF8String = v.toUpperCase
+ override def convert(v: UTF8String): UTF8String = v.toUpperCase()
override def toString: String = s"Upper($child)"
}
@@ -146,21 +148,21 @@ case class Upper(child: Expression) extends UnaryExpression with CaseConversionE
*/
case class Lower(child: Expression) extends UnaryExpression with CaseConversionExpression {
- override def convert(v: UTF8String): UTF8String = v.toLowerCase
+ override def convert(v: UTF8String): UTF8String = v.toLowerCase()
override def toString: String = s"Lower($child)"
}
/** A base trait for functions that compare two strings, returning a boolean. */
trait StringComparison {
- self: BinaryPredicate =>
+ self: BinaryExpression =>
+
+ def compare(l: UTF8String, r: UTF8String): Boolean
override type EvaluatedType = Any
override def nullable: Boolean = left.nullable || right.nullable
- def compare(l: UTF8String, r: UTF8String): Boolean
-
override def eval(input: Row): Any = {
val leftEval = left.eval(input)
if(leftEval == null) {
@@ -181,31 +183,35 @@ trait StringComparison {
* A function that returns true if the string `left` contains the string `right`.
*/
case class Contains(left: Expression, right: Expression)
- extends BinaryPredicate with StringComparison {
+ extends BinaryExpression with Predicate with StringComparison with ExpectsInputTypes {
override def compare(l: UTF8String, r: UTF8String): Boolean = l.contains(r)
+ override def expectedChildTypes: Seq[DataType] = Seq(StringType, StringType)
}
/**
* A function that returns true if the string `left` starts with the string `right`.
*/
case class StartsWith(left: Expression, right: Expression)
- extends BinaryPredicate with StringComparison {
+ extends BinaryExpression with Predicate with StringComparison with ExpectsInputTypes {
override def compare(l: UTF8String, r: UTF8String): Boolean = l.startsWith(r)
+ override def expectedChildTypes: Seq[DataType] = Seq(StringType, StringType)
}
/**
* A function that returns true if the string `left` ends with the string `right`.
*/
case class EndsWith(left: Expression, right: Expression)
- extends BinaryPredicate with StringComparison {
+ extends BinaryExpression with Predicate with StringComparison with ExpectsInputTypes {
override def compare(l: UTF8String, r: UTF8String): Boolean = l.endsWith(r)
+ override def expectedChildTypes: Seq[DataType] = Seq(StringType, StringType)
}
/**
* A function that takes a substring of its first argument starting at a given position.
* Defined for String and Binary types.
*/
-case class Substring(str: Expression, pos: Expression, len: Expression) extends Expression {
+case class Substring(str: Expression, pos: Expression, len: Expression)
+ extends Expression with ExpectsInputTypes {
type EvaluatedType = Any
@@ -219,6 +225,8 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends
if (str.dataType == BinaryType) str.dataType else StringType
}
+ override def expectedChildTypes: Seq[DataType] = Seq(StringType, IntegerType, IntegerType)
+
override def children: Seq[Expression] = str :: pos :: len :: Nil
@inline
@@ -258,7 +266,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends
val (st, end) = slicePos(start, length, () => ba.length)
ba.slice(st, end)
case s: UTF8String =>
- val (st, end) = slicePos(start, length, () => s.length)
+ val (st, end) = slicePos(start, length, () => s.length())
s.slice(st, end)
}
}