about | summary | refs | log | tree | commit | diff
path: root/sql
diff options
context:
space:
mode:
author: Wenchen Fan <cloud0fan@outlook.com> 2015-07-15 22:27:39 -0700
committer: Reynold Xin <rxin@databricks.com> 2015-07-15 22:27:39 -0700
commit ba33096846dc8061e97a7bf8f3b46f899d530159 (patch)
tree e4e28d2a0c552dc42fd995b1241067298282bc16 /sql
parent 42dea3acf90ec506a0b79720b55ae1d753cc7544 (diff)
download spark-ba33096846dc8061e97a7bf8f3b46f899d530159.tar.gz
spark-ba33096846dc8061e97a7bf8f3b46f899d530159.tar.bz2
spark-ba33096846dc8061e97a7bf8f3b46f899d530159.zip
[SPARK-9068][SQL] refactor the implicit type cast code
based on https://github.com/apache/spark/pull/7348 Author: Wenchen Fan <cloud0fan@outlook.com> Closes #7420 from cloud-fan/type-check and squashes the following commits: 7633fa9 [Wenchen Fan] revert fe169b0 [Wenchen Fan] improve test 03b70da [Wenchen Fan] enhance implicit type cast
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala | 33
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala | 20
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala | 8
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala | 4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala | 45
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala | 2
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala | 75
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala | 10
13 files changed, 81 insertions, 126 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 25087915b5..50db7d21f0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -675,10 +675,10 @@ object HiveTypeCoercion {
case b @ BinaryOperator(left, right) if left.dataType != right.dataType =>
findTightestCommonTypeOfTwo(left.dataType, right.dataType).map { commonType =>
if (b.inputType.acceptsType(commonType)) {
- // If the expression accepts the tighest common type, cast to that.
+ // If the expression accepts the tightest common type, cast to that.
val newLeft = if (left.dataType == commonType) left else Cast(left, commonType)
val newRight = if (right.dataType == commonType) right else Cast(right, commonType)
- b.makeCopy(Array(newLeft, newRight))
+ b.withNewChildren(Seq(newLeft, newRight))
} else {
// Otherwise, don't do anything with the expression.
b
@@ -697,7 +697,7 @@ object HiveTypeCoercion {
// general implicit casting.
val children: Seq[Expression] = e.children.zip(e.inputTypes).map { case (in, expected) =>
if (in.dataType == NullType && !expected.acceptsType(NullType)) {
- Cast(in, expected.defaultConcreteType)
+ Literal.create(null, expected.defaultConcreteType)
} else {
in
}
@@ -719,27 +719,22 @@ object HiveTypeCoercion {
@Nullable val ret: Expression = (inType, expectedType) match {
// If the expected type is already a parent of the input type, no need to cast.
- case _ if expectedType.isSameType(inType) => e
+ case _ if expectedType.acceptsType(inType) => e
// Cast null type (usually from null literals) into target types
case (NullType, target) => Cast(e, target.defaultConcreteType)
- // If the function accepts any numeric type (i.e. the ADT `NumericType`) and the input is
- // already a number, leave it as is.
- case (_: NumericType, NumericType) => e
-
// If the function accepts any numeric type and the input is a string, we follow the hive
// convention and cast that input into a double
case (StringType, NumericType) => Cast(e, NumericType.defaultConcreteType)
- // Implicit cast among numeric types
+ // Implicit cast among numeric types. When we reach here, input type is not acceptable.
+
// If input is a numeric type but not decimal, and we expect a decimal type,
// cast the input to unlimited precision decimal.
- case (_: NumericType, DecimalType) if !inType.isInstanceOf[DecimalType] =>
- Cast(e, DecimalType.Unlimited)
+ case (_: NumericType, DecimalType) => Cast(e, DecimalType.Unlimited)
// For any other numeric types, implicitly cast to each other, e.g. long -> int, int -> long
- case (_: NumericType, target: NumericType) if e.dataType != target => Cast(e, target)
- case (_: NumericType, target: NumericType) => e
+ case (_: NumericType, target: NumericType) => Cast(e, target)
// Implicit cast between date time types
case (DateType, TimestampType) => Cast(e, TimestampType)
@@ -753,15 +748,9 @@ object HiveTypeCoercion {
case (StringType, BinaryType) => Cast(e, BinaryType)
case (any, StringType) if any != StringType => Cast(e, StringType)
- // Type collection.
- // First see if we can find our input type in the type collection. If we can, then just
- // use the current expression; otherwise, find the first one we can implicitly cast.
- case (_, TypeCollection(types)) =>
- if (types.exists(_.isSameType(inType))) {
- e
- } else {
- types.flatMap(implicitCast(e, _)).headOption.orNull
- }
+ // When we reach here, input type is not acceptable for any types in this type collection,
+ // try to find the first one we can implicitly cast.
+ case (_, TypeCollection(types)) => types.flatMap(implicitCast(e, _)).headOption.orNull
// Else, just return the same input expression
case _ => null
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 87667316ac..a655cc8e48 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -386,17 +386,15 @@ abstract class BinaryOperator extends BinaryExpression with ExpectsInputTypes {
override def inputTypes: Seq[AbstractDataType] = Seq(inputType, inputType)
override def checkInputDataTypes(): TypeCheckResult = {
- // First call the checker for ExpectsInputTypes, and then check whether left and right have
- // the same type.
- super.checkInputDataTypes() match {
- case TypeCheckResult.TypeCheckSuccess =>
- if (left.dataType != right.dataType) {
- TypeCheckResult.TypeCheckFailure(s"differing types in '$prettyString' " +
- s"(${left.dataType.simpleString} and ${right.dataType.simpleString}).")
- } else {
- TypeCheckResult.TypeCheckSuccess
- }
- case TypeCheckResult.TypeCheckFailure(msg) => TypeCheckResult.TypeCheckFailure(msg)
+ // First check whether left and right have the same type, then check if the type is acceptable.
+ if (left.dataType != right.dataType) {
+ TypeCheckResult.TypeCheckFailure(s"differing types in '$prettyString' " +
+ s"(${left.dataType.simpleString} and ${right.dataType.simpleString}).")
+ } else if (!inputType.acceptsType(left.dataType)) {
+ TypeCheckResult.TypeCheckFailure(s"'$prettyString' accepts ${inputType.simpleString} type," +
+ s" not ${left.dataType.simpleString}")
+ } else {
+ TypeCheckResult.TypeCheckSuccess
}
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 394ef556e0..382cbe3b84 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -320,7 +320,6 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic {
}
override def symbol: String = "max"
- override def prettyName: String = symbol
}
case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic {
@@ -375,7 +374,6 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic {
}
override def symbol: String = "min"
- override def prettyName: String = symbol
}
case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala
index af1abbcd22..a1e48c4210 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.types._
*/
case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithmetic {
- override def inputType: AbstractDataType = TypeCollection.Bitwise
+ override def inputType: AbstractDataType = IntegralType
override def symbol: String = "&"
@@ -53,7 +53,7 @@ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithme
*/
case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmetic {
- override def inputType: AbstractDataType = TypeCollection.Bitwise
+ override def inputType: AbstractDataType = IntegralType
override def symbol: String = "|"
@@ -78,7 +78,7 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet
*/
case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithmetic {
- override def inputType: AbstractDataType = TypeCollection.Bitwise
+ override def inputType: AbstractDataType = IntegralType
override def symbol: String = "^"
@@ -101,7 +101,7 @@ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithme
*/
case class BitwiseNot(child: Expression) extends UnaryExpression with ExpectsInputTypes {
- override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.Bitwise)
+ override def inputTypes: Seq[AbstractDataType] = Seq(IntegralType)
override def dataType: DataType = child.dataType
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala
index c7f039ede2..9162b73fe5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala
@@ -35,8 +35,8 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
TypeCheckResult.TypeCheckFailure(
s"type of predicate expression in If should be boolean, not ${predicate.dataType}")
} else if (trueValue.dataType != falseValue.dataType) {
- TypeCheckResult.TypeCheckFailure(
- s"differing types in If (${trueValue.dataType} and ${falseValue.dataType}).")
+ TypeCheckResult.TypeCheckFailure(s"differing types in '$prettyString' " +
+ s"(${trueValue.dataType.simpleString} and ${falseValue.dataType.simpleString}).")
} else {
TypeCheckResult.TypeCheckSuccess
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
index f5715f7a82..076d7b5a51 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
@@ -34,32 +34,18 @@ private[sql] abstract class AbstractDataType {
private[sql] def defaultConcreteType: DataType
/**
- * Returns true if this data type is the same type as `other`. This is different that equality
- * as equality will also consider data type parametrization, such as decimal precision.
+ * Returns true if `other` is an acceptable input type for a function that expects this,
+ * possibly abstract DataType.
*
* {{{
* // this should return true
- * DecimalType.isSameType(DecimalType(10, 2))
- *
- * // this should return false
- * NumericType.isSameType(DecimalType(10, 2))
- * }}}
- */
- private[sql] def isSameType(other: DataType): Boolean
-
- /**
- * Returns true if `other` is an acceptable input type for a function that expectes this,
- * possibly abstract, DataType.
- *
- * {{{
- * // this should return true
- * DecimalType.isSameType(DecimalType(10, 2))
+ * DecimalType.acceptsType(DecimalType(10, 2))
*
* // this should return true as well
* NumericType.acceptsType(DecimalType(10, 2))
* }}}
*/
- private[sql] def acceptsType(other: DataType): Boolean = isSameType(other)
+ private[sql] def acceptsType(other: DataType): Boolean
/** Readable string representation for the type. */
private[sql] def simpleString: String
@@ -83,10 +69,8 @@ private[sql] class TypeCollection(private val types: Seq[AbstractDataType])
override private[sql] def defaultConcreteType: DataType = types.head.defaultConcreteType
- override private[sql] def isSameType(other: DataType): Boolean = false
-
override private[sql] def acceptsType(other: DataType): Boolean =
- types.exists(_.isSameType(other))
+ types.exists(_.acceptsType(other))
override private[sql] def simpleString: String = {
types.map(_.simpleString).mkString("(", " or ", ")")
@@ -107,13 +91,6 @@ private[sql] object TypeCollection {
TimestampType, DateType,
StringType, BinaryType)
- /**
- * Types that can be used in bitwise operations.
- */
- val Bitwise = TypeCollection(
- BooleanType,
- ByteType, ShortType, IntegerType, LongType)
-
def apply(types: AbstractDataType*): TypeCollection = new TypeCollection(types)
def unapply(typ: AbstractDataType): Option[Seq[AbstractDataType]] = typ match {
@@ -134,8 +111,6 @@ protected[sql] object AnyDataType extends AbstractDataType {
override private[sql] def simpleString: String = "any"
- override private[sql] def isSameType(other: DataType): Boolean = false
-
override private[sql] def acceptsType(other: DataType): Boolean = true
}
@@ -183,13 +158,11 @@ private[sql] object NumericType extends AbstractDataType {
override private[sql] def simpleString: String = "numeric"
- override private[sql] def isSameType(other: DataType): Boolean = false
-
override private[sql] def acceptsType(other: DataType): Boolean = other.isInstanceOf[NumericType]
}
-private[sql] object IntegralType {
+private[sql] object IntegralType extends AbstractDataType {
/**
* Enables matching against IntegralType for expressions:
* {{{
@@ -198,6 +171,12 @@ private[sql] object IntegralType {
* }}}
*/
def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[IntegralType]
+
+ override private[sql] def defaultConcreteType: DataType = IntegerType
+
+ override private[sql] def simpleString: String = "integral"
+
+ override private[sql] def acceptsType(other: DataType): Boolean = other.isInstanceOf[IntegralType]
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala
index 76ca7a84c1..5094058164 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala
@@ -28,7 +28,7 @@ object ArrayType extends AbstractDataType {
override private[sql] def defaultConcreteType: DataType = ArrayType(NullType, containsNull = true)
- override private[sql] def isSameType(other: DataType): Boolean = {
+ override private[sql] def acceptsType(other: DataType): Boolean = {
other.isInstanceOf[ArrayType]
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
index da83a7f0ba..2d133eea19 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
@@ -79,7 +79,7 @@ abstract class DataType extends AbstractDataType {
override private[sql] def defaultConcreteType: DataType = this
- override private[sql] def isSameType(other: DataType): Boolean = this == other
+ override private[sql] def acceptsType(other: DataType): Boolean = this == other
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
index a1cafeab17..377c75f6e8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
@@ -86,7 +86,7 @@ object DecimalType extends AbstractDataType {
override private[sql] def defaultConcreteType: DataType = Unlimited
- override private[sql] def isSameType(other: DataType): Boolean = {
+ override private[sql] def acceptsType(other: DataType): Boolean = {
other.isInstanceOf[DecimalType]
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala
index ddead10bc2..ac34b64282 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala
@@ -71,7 +71,7 @@ object MapType extends AbstractDataType {
override private[sql] def defaultConcreteType: DataType = apply(NullType, NullType)
- override private[sql] def isSameType(other: DataType): Boolean = {
+ override private[sql] def acceptsType(other: DataType): Boolean = {
other.isInstanceOf[MapType]
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
index b8097403ec..2ef97a427c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
@@ -307,7 +307,7 @@ object StructType extends AbstractDataType {
override private[sql] def defaultConcreteType: DataType = new StructType
- override private[sql] def isSameType(other: DataType): Boolean = {
+ override private[sql] def acceptsType(other: DataType): Boolean = {
other.isInstanceOf[StructType]
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
index a4ce1825ca..ed0d20e7de 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
-import org.apache.spark.sql.types.StringType
+import org.apache.spark.sql.types.{TypeCollection, StringType}
class ExpressionTypeCheckingSuite extends SparkFunSuite {
@@ -49,23 +49,16 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
def assertErrorForDifferingTypes(expr: Expression): Unit = {
assertError(expr,
- s"differing types in '${expr.prettyString}' (int and boolean)")
- }
-
- def assertErrorWithImplicitCast(expr: Expression, errorMessage: String): Unit = {
- val e = intercept[AnalysisException] {
- assertSuccess(expr)
- }
- assert(e.getMessage.contains(errorMessage))
+ s"differing types in '${expr.prettyString}'")
}
test("check types for unary arithmetic") {
assertError(UnaryMinus('stringField), "expected to be of type numeric")
assertError(Abs('stringField), "expected to be of type numeric")
- assertError(BitwiseNot('stringField), "type (boolean or tinyint or smallint or int or bigint)")
+ assertError(BitwiseNot('stringField), "expected to be of type integral")
}
- ignore("check types for binary arithmetic") {
+ test("check types for binary arithmetic") {
// We will cast String to Double for binary arithmetic
assertSuccess(Add('intField, 'stringField))
assertSuccess(Subtract('intField, 'stringField))
@@ -85,21 +78,23 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
assertErrorForDifferingTypes(MaxOf('intField, 'booleanField))
assertErrorForDifferingTypes(MinOf('intField, 'booleanField))
- assertError(Add('booleanField, 'booleanField), "operator + accepts numeric type")
- assertError(Subtract('booleanField, 'booleanField), "operator - accepts numeric type")
- assertError(Multiply('booleanField, 'booleanField), "operator * accepts numeric type")
- assertError(Divide('booleanField, 'booleanField), "operator / accepts numeric type")
- assertError(Remainder('booleanField, 'booleanField), "operator % accepts numeric type")
+ assertError(Add('booleanField, 'booleanField), "accepts numeric type")
+ assertError(Subtract('booleanField, 'booleanField), "accepts numeric type")
+ assertError(Multiply('booleanField, 'booleanField), "accepts numeric type")
+ assertError(Divide('booleanField, 'booleanField), "accepts numeric type")
+ assertError(Remainder('booleanField, 'booleanField), "accepts numeric type")
- assertError(BitwiseAnd('booleanField, 'booleanField), "operator & accepts integral type")
- assertError(BitwiseOr('booleanField, 'booleanField), "operator | accepts integral type")
- assertError(BitwiseXor('booleanField, 'booleanField), "operator ^ accepts integral type")
+ assertError(BitwiseAnd('booleanField, 'booleanField), "accepts integral type")
+ assertError(BitwiseOr('booleanField, 'booleanField), "accepts integral type")
+ assertError(BitwiseXor('booleanField, 'booleanField), "accepts integral type")
- assertError(MaxOf('complexField, 'complexField), "function maxOf accepts non-complex type")
- assertError(MinOf('complexField, 'complexField), "function minOf accepts non-complex type")
+ assertError(MaxOf('complexField, 'complexField),
+ s"accepts ${TypeCollection.Ordered.simpleString} type")
+ assertError(MinOf('complexField, 'complexField),
+ s"accepts ${TypeCollection.Ordered.simpleString} type")
}
- ignore("check types for predicates") {
+ test("check types for predicates") {
// We will cast String to Double for binary comparison
assertSuccess(EqualTo('intField, 'stringField))
assertSuccess(EqualNullSafe('intField, 'stringField))
@@ -112,25 +107,23 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
assertSuccess(EqualTo('intField, 'booleanField))
assertSuccess(EqualNullSafe('intField, 'booleanField))
- assertError(EqualTo('intField, 'complexField), "differing types")
- assertError(EqualNullSafe('intField, 'complexField), "differing types")
-
+ assertErrorForDifferingTypes(EqualTo('intField, 'complexField))
+ assertErrorForDifferingTypes(EqualNullSafe('intField, 'complexField))
assertErrorForDifferingTypes(LessThan('intField, 'booleanField))
assertErrorForDifferingTypes(LessThanOrEqual('intField, 'booleanField))
assertErrorForDifferingTypes(GreaterThan('intField, 'booleanField))
assertErrorForDifferingTypes(GreaterThanOrEqual('intField, 'booleanField))
- assertError(
- LessThan('complexField, 'complexField), "operator < accepts non-complex type")
- assertError(
- LessThanOrEqual('complexField, 'complexField), "operator <= accepts non-complex type")
- assertError(
- GreaterThan('complexField, 'complexField), "operator > accepts non-complex type")
- assertError(
- GreaterThanOrEqual('complexField, 'complexField), "operator >= accepts non-complex type")
+ assertError(LessThan('complexField, 'complexField),
+ s"accepts ${TypeCollection.Ordered.simpleString} type")
+ assertError(LessThanOrEqual('complexField, 'complexField),
+ s"accepts ${TypeCollection.Ordered.simpleString} type")
+ assertError(GreaterThan('complexField, 'complexField),
+ s"accepts ${TypeCollection.Ordered.simpleString} type")
+ assertError(GreaterThanOrEqual('complexField, 'complexField),
+ s"accepts ${TypeCollection.Ordered.simpleString} type")
- assertError(
- If('intField, 'stringField, 'stringField),
+ assertError(If('intField, 'stringField, 'stringField),
"type of predicate expression in If should be boolean")
assertErrorForDifferingTypes(If('booleanField, 'intField, 'booleanField))
@@ -180,12 +173,12 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
}
test("check types for ROUND") {
- assertErrorWithImplicitCast(Round(Literal(null), 'booleanField),
- "data type mismatch: argument 2 is expected to be of type int")
- assertErrorWithImplicitCast(Round(Literal(null), 'complexField),
- "data type mismatch: argument 2 is expected to be of type int")
assertSuccess(Round(Literal(null), Literal(null)))
- assertError(Round('booleanField, 'intField),
- "data type mismatch: argument 1 is expected to be of type numeric")
+ assertSuccess(Round('intField, Literal(1)))
+
+ assertError(Round('intField, 'intField), "Only foldable Expression is allowed")
+ assertError(Round('intField, 'booleanField), "expected to be of type int")
+ assertError(Round('intField, 'complexField), "expected to be of type int")
+ assertError(Round('booleanField, 'intField), "expected to be of type numeric")
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
index 8e9b20a3eb..d0fd033b98 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
@@ -203,7 +203,7 @@ class HiveTypeCoercionSuite extends PlanTest {
ruleTest(HiveTypeCoercion.ImplicitTypeCasts,
NumericTypeUnaryExpression(Literal.create(null, NullType)),
- NumericTypeUnaryExpression(Cast(Literal.create(null, NullType), DoubleType)))
+ NumericTypeUnaryExpression(Literal.create(null, DoubleType)))
}
test("cast NullType for binary operators") {
@@ -215,9 +215,7 @@ class HiveTypeCoercionSuite extends PlanTest {
ruleTest(HiveTypeCoercion.ImplicitTypeCasts,
NumericTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)),
- NumericTypeBinaryOperator(
- Cast(Literal.create(null, NullType), DoubleType),
- Cast(Literal.create(null, NullType), DoubleType)))
+ NumericTypeBinaryOperator(Literal.create(null, DoubleType), Literal.create(null, DoubleType)))
}
test("coalesce casts") {
@@ -345,14 +343,14 @@ object HiveTypeCoercionSuite {
}
case class AnyTypeBinaryOperator(left: Expression, right: Expression)
- extends BinaryOperator with ExpectsInputTypes {
+ extends BinaryOperator {
override def dataType: DataType = NullType
override def inputType: AbstractDataType = AnyDataType
override def symbol: String = "anytype"
}
case class NumericTypeBinaryOperator(left: Expression, right: Expression)
- extends BinaryOperator with ExpectsInputTypes {
+ extends BinaryOperator {
override def dataType: DataType = NullType
override def inputType: AbstractDataType = NumericType
override def symbol: String = "numerictype"