author      Wenchen Fan <cloud0fan@outlook.com>          2015-06-24 16:26:00 -0700
committer   Michael Armbrust <michael@databricks.com>    2015-06-24 16:26:00 -0700
commit      b71d3254e50838ccae43bdb0ff186fda25f03152 (patch)
tree        9a22820e88da2970e4144cb8f82db9309ef38a9f
parent      7daa70292e47be6a944351ef00c770ad4bcb0877 (diff)
[SPARK-8075] [SQL] apply type check interface to more expressions
A follow-up of https://github.com/apache/spark/pull/6405. Note: this is not a big change; much of the diff comes from moving code in `aggregates.scala` so that each aggregate function sits directly below its corresponding aggregate expression.

Author: Wenchen Fan <cloud0fan@outlook.com>

Closes #6723 from cloud-fan/type-check and squashes the following commits:

2124301 [Wenchen Fan] fix tests
5a658bb [Wenchen Fan] add tests
287d3bb [Wenchen Fan] apply type check interface to more expressions
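The "type check interface" named in the title is the TypeCheckResult returned by each expression's checkInputDataTypes(). For orientation, a minimal self-contained Scala sketch of that interface's shape (a simplification; the real classes live under org.apache.spark.sql.catalyst.analysis):

    // Simplified model of TypeCheckResult: either all input types are
    // acceptable, or the check fails with a human-readable reason.
    sealed trait TypeCheckResult {
      def isSuccess: Boolean
    }

    object TypeCheckResult {
      case object TypeCheckSuccess extends TypeCheckResult {
        override def isSuccess: Boolean = true
      }
      case class TypeCheckFailure(message: String) extends TypeCheckResult {
        override def isSuccess: Boolean = false
      }
    }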
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala | 17
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala | 11
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala | 4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala | 10
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala | 420
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala | 30
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala | 17
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala | 13
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala | 4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala | 4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala | 27
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala | 10
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala | 3
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala | 9
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala | 6
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala (renamed from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala) | 26
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala | 2
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala | 6
21 files changed, 337 insertions(+), 290 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index b06759f144..cad2c2abe6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -587,8 +587,8 @@ class Analyzer(
failAnalysis(
s"""Expect multiple names given for ${g.getClass.getName},
|but only single name '${name}' specified""".stripMargin)
- case Alias(g: Generator, name) => Some((g, name :: Nil))
- case MultiAlias(g: Generator, names) => Some(g, names)
+ case Alias(g: Generator, name) if g.resolved => Some((g, name :: Nil))
+ case MultiAlias(g: Generator, names) if g.resolved => Some(g, names)
case _ => None
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index d4ab1fc643..4ef7341a33 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -317,6 +317,7 @@ trait HiveTypeCoercion {
i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType))))
case Sum(e @ StringType()) => Sum(Cast(e, DoubleType))
+ case SumDistinct(e @ StringType()) => Sum(Cast(e, DoubleType))
case Average(e @ StringType()) => Average(Cast(e, DoubleType))
}
}
@@ -590,11 +591,12 @@ trait HiveTypeCoercion {
// Skip nodes whose children have not been resolved yet.
case e if !e.childrenResolved => e
- case a @ CreateArray(children) if !a.resolved =>
- val commonType = a.childTypes.reduce(
- (a, b) => findTightestCommonTypeOfTwo(a, b).getOrElse(StringType))
- CreateArray(
- children.map(c => if (c.dataType == commonType) c else Cast(c, commonType)))
+ case a @ CreateArray(children) if children.map(_.dataType).distinct.size > 1 =>
+ val types = children.map(_.dataType)
+ findTightestCommonTypeAndPromoteToString(types) match {
+ case Some(finalDataType) => CreateArray(children.map(Cast(_, finalDataType)))
+ case None => a
+ }
// Promote SUM, SUM DISTINCT and AVERAGE to largest types to prevent overflows.
case s @ Sum(e @ DecimalType()) => s // Decimal is already the biggest.
@@ -620,12 +622,11 @@ trait HiveTypeCoercion {
// Coalesce should return the first non-null value, which could be any column
// from the list. So we need to make sure the return type is deterministic and
// compatible with every child column.
- case Coalesce(es) if es.map(_.dataType).distinct.size > 1 =>
+ case c @ Coalesce(es) if es.map(_.dataType).distinct.size > 1 =>
val types = es.map(_.dataType)
findTightestCommonTypeAndPromoteToString(types) match {
case Some(finalDataType) => Coalesce(es.map(Cast(_, finalDataType)))
- case None =>
- sys.error(s"Could not determine return type of Coalesce for ${types.mkString(",")}")
+ case None => c
}
}
}
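Both rewritten rules above lean on findTightestCommonTypeAndPromoteToString. A rough sketch of its behavior over a toy type model (the three-type lattice here is an illustrative assumption, not Spark's actual one):

    trait DT
    case object IntT extends DT
    case object DoubleT extends DT
    case object StringT extends DT

    // Tightest common type of two types: identical types stay put, numerics
    // widen, and anything paired with a string promotes to string.
    def tightestCommonType(a: DT, b: DT): Option[DT] = (a, b) match {
      case (x, y) if x == y                  => Some(x)
      case (IntT, DoubleT) | (DoubleT, IntT) => Some(DoubleT)
      case (StringT, _) | (_, StringT)       => Some(StringT)
      case _                                 => None
    }

    def findTightestCommonTypeAndPromoteToString(types: Seq[DT]): Option[DT] =
      types.headOption.flatMap { head =>
        types.tail.foldLeft(Option(head))((acc, t) => acc.flatMap(tightestCommonType(_, t)))
      }

Note the behavioral change in the Coalesce hunk: when no common type exists, the rule now returns the expression unchanged so the error surfaces later through Coalesce's own checkInputDataTypes, instead of crashing analysis with sys.error.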
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index d271434a30..8bd7fc18a8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -22,7 +22,7 @@ import java.sql.{Date, Timestamp}
import java.text.{DateFormat, SimpleDateFormat}
import org.apache.spark.Logging
-import org.apache.spark.sql.catalyst
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
@@ -31,7 +31,14 @@ import org.apache.spark.unsafe.types.UTF8String
/** Cast the child expression to the target data type. */
case class Cast(child: Expression, dataType: DataType) extends UnaryExpression with Logging {
- override lazy val resolved = childrenResolved && resolve(child.dataType, dataType)
+ override def checkInputDataTypes(): TypeCheckResult = {
+ if (resolve(child.dataType, dataType)) {
+ TypeCheckResult.TypeCheckSuccess
+ } else {
+ TypeCheckResult.TypeCheckFailure(
+ s"cannot cast ${child.dataType} to $dataType")
+ }
+ }
override def foldable: Boolean = child.foldable
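With Cast, the hand-written resolved override gives way to the shared check. The same pattern in isolation, reusing the sketches above (the castability predicate is a hypothetical stand-in for Cast's private resolve method):

    // Succeed only when the source/target pair is castable; otherwise report
    // both types, matching the "cannot cast" message asserted in AnalysisSuite.
    def checkCast(from: DT, to: DT)(castable: (DT, DT) => Boolean): TypeCheckResult =
      if (castable(from, to)) TypeCheckResult.TypeCheckSuccess
      else TypeCheckResult.TypeCheckFailure(s"cannot cast $from to $to")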
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index a10a959ae7..f59db3d5df 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -162,9 +162,7 @@ abstract class Expression extends TreeNode[Expression] {
/**
* Checks the input data types, returns `TypeCheckResult.success` if it's valid,
* or returns a `TypeCheckResult` with an error message if invalid.
- * Note: it's not valid to call this method until `childrenResolved == true`
- * TODO: we should remove the default implementation and implement it for all
- * expressions with proper error message.
+ * Note: it's not valid to call this method until `childrenResolved == true`.
*/
def checkInputDataTypes(): TypeCheckResult = TypeCheckResult.TypeCheckSuccess
}
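The base Expression class is what makes this contract safe: an expression only counts as resolved once its children are resolved and the type check passes, so checkInputDataTypes can freely read child.dataType. Condensed, the gating looks like this (mirroring the Alias and WindowSpecDefinition hunks later in this diff):

    // childrenResolved is evaluated first and && short-circuits, so the
    // check never runs against an unresolved child.
    def resolvedFor(childrenResolved: Boolean, check: () => TypeCheckResult): Boolean =
      childrenResolved && check().isSuccess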
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala
index 4d6c1c2651..4d7c95ffd1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala
@@ -96,6 +96,11 @@ object ExtractValue {
}
}
+/**
+ * A common interface of all kinds of extract value expressions.
+ * Note: concrete extract value expressions are created only by `ExtractValue.apply`,
+ * we don't need to do type check for them.
+ */
trait ExtractValue extends UnaryExpression {
self: Product =>
}
@@ -179,9 +184,6 @@ case class GetArrayItem(child: Expression, ordinal: Expression)
override def dataType: DataType = child.dataType.asInstanceOf[ArrayType].elementType
- override lazy val resolved = childrenResolved &&
- child.dataType.isInstanceOf[ArrayType] && ordinal.dataType.isInstanceOf[IntegralType]
-
protected def evalNotNull(value: Any, ordinal: Any) = {
// TODO: consider using Array[_] for ArrayType child to avoid
// boxing of primitives
@@ -203,8 +205,6 @@ case class GetMapValue(child: Expression, ordinal: Expression)
override def dataType: DataType = child.dataType.asInstanceOf[MapType].valueType
- override lazy val resolved = childrenResolved && child.dataType.isInstanceOf[MapType]
-
protected def evalNotNull(value: Any, ordinal: Any) = {
val baseValue = value.asInstanceOf[Map[Any, _]]
baseValue.get(ordinal).orNull
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index 00d2e499c5..a9fc54c548 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -19,9 +19,10 @@ package org.apache.spark.sql.catalyst.expressions
import com.clearspring.analytics.stream.cardinality.HyperLogLog
-import org.apache.spark.sql.catalyst
-import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.trees
+import org.apache.spark.sql.catalyst.errors.TreeNodeException
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
import org.apache.spark.util.collection.OpenHashSet
@@ -101,6 +102,9 @@ case class Min(child: Expression) extends PartialAggregate with trees.UnaryNode[
}
override def newInstance(): MinFunction = new MinFunction(child, this)
+
+ override def checkInputDataTypes(): TypeCheckResult =
+ TypeUtils.checkForOrderingExpr(child.dataType, "function min")
}
case class MinFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
@@ -132,6 +136,9 @@ case class Max(child: Expression) extends PartialAggregate with trees.UnaryNode[
}
override def newInstance(): MaxFunction = new MaxFunction(child, this)
+
+ override def checkInputDataTypes(): TypeCheckResult =
+ TypeUtils.checkForOrderingExpr(child.dataType, "function max")
}
case class MaxFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
@@ -165,6 +172,21 @@ case class Count(child: Expression) extends PartialAggregate with trees.UnaryNod
override def newInstance(): CountFunction = new CountFunction(child, this)
}
+case class CountFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
+ def this() = this(null, null) // Required for serialization.
+
+ var count: Long = _
+
+ override def update(input: InternalRow): Unit = {
+ val evaluatedExpr = expr.eval(input)
+ if (evaluatedExpr != null) {
+ count += 1L
+ }
+ }
+
+ override def eval(input: InternalRow): Any = count
+}
+
case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate {
def this() = this(null)
@@ -183,6 +205,28 @@ case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate
}
}
+case class CountDistinctFunction(
+ @transient expr: Seq[Expression],
+ @transient base: AggregateExpression)
+ extends AggregateFunction {
+
+ def this() = this(null, null) // Required for serialization.
+
+ val seen = new OpenHashSet[Any]()
+
+ @transient
+ val distinctValue = new InterpretedProjection(expr)
+
+ override def update(input: InternalRow): Unit = {
+ val evaluatedExpr = distinctValue(input)
+ if (!evaluatedExpr.anyNull) {
+ seen.add(evaluatedExpr)
+ }
+ }
+
+ override def eval(input: InternalRow): Any = seen.size.toLong
+}
+
case class CollectHashSet(expressions: Seq[Expression]) extends AggregateExpression {
def this() = this(null)
@@ -278,6 +322,25 @@ case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double)
}
}
+case class ApproxCountDistinctPartitionFunction(
+ expr: Expression,
+ base: AggregateExpression,
+ relativeSD: Double)
+ extends AggregateFunction {
+ def this() = this(null, null, 0) // Required for serialization.
+
+ private val hyperLogLog = new HyperLogLog(relativeSD)
+
+ override def update(input: InternalRow): Unit = {
+ val evaluatedExpr = expr.eval(input)
+ if (evaluatedExpr != null) {
+ hyperLogLog.offer(evaluatedExpr)
+ }
+ }
+
+ override def eval(input: InternalRow): Any = hyperLogLog
+}
+
case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double)
extends AggregateExpression with trees.UnaryNode[Expression] {
@@ -289,6 +352,23 @@ case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double)
}
}
+case class ApproxCountDistinctMergeFunction(
+ expr: Expression,
+ base: AggregateExpression,
+ relativeSD: Double)
+ extends AggregateFunction {
+ def this() = this(null, null, 0) // Required for serialization.
+
+ private val hyperLogLog = new HyperLogLog(relativeSD)
+
+ override def update(input: InternalRow): Unit = {
+ val evaluatedExpr = expr.eval(input)
+ hyperLogLog.addAll(evaluatedExpr.asInstanceOf[HyperLogLog])
+ }
+
+ override def eval(input: InternalRow): Any = hyperLogLog.cardinality()
+}
+
case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05)
extends PartialAggregate with trees.UnaryNode[Expression] {
@@ -349,159 +429,9 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN
}
override def newInstance(): AverageFunction = new AverageFunction(child, this)
-}
-
-case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
-
- override def nullable: Boolean = true
-
- override def dataType: DataType = child.dataType match {
- case DecimalType.Fixed(precision, scale) =>
- DecimalType(precision + 10, scale) // Add 10 digits left of decimal point, like Hive
- case DecimalType.Unlimited =>
- DecimalType.Unlimited
- case _ =>
- child.dataType
- }
-
- override def toString: String = s"SUM($child)"
-
- override def asPartial: SplitEvaluation = {
- child.dataType match {
- case DecimalType.Fixed(_, _) =>
- val partialSum = Alias(Sum(Cast(child, DecimalType.Unlimited)), "PartialSum")()
- SplitEvaluation(
- Cast(CombineSum(partialSum.toAttribute), dataType),
- partialSum :: Nil)
-
- case _ =>
- val partialSum = Alias(Sum(child), "PartialSum")()
- SplitEvaluation(
- CombineSum(partialSum.toAttribute),
- partialSum :: Nil)
- }
- }
-
- override def newInstance(): SumFunction = new SumFunction(child, this)
-}
-
-/**
- * Sum should satisfy 3 cases:
- * 1) sum of all null values = zero
- * 2) sum for table column with no data = null
- * 3) sum of column with null and not null values = sum of not null values
- * Require separate CombineSum Expression and function as it has to distinguish "No data" case
- * versus "data equals null" case, while aggregating results and at each partial expression, i.e.,
- * Combining PartitionLevel InputData
- * <-- null
- * Zero <-- Zero <-- null
- *
- * <-- null <-- no data
- * null <-- null <-- no data
- */
-case class CombineSum(child: Expression) extends AggregateExpression {
- def this() = this(null)
-
- override def children: Seq[Expression] = child :: Nil
- override def nullable: Boolean = true
- override def dataType: DataType = child.dataType
- override def toString: String = s"CombineSum($child)"
- override def newInstance(): CombineSumFunction = new CombineSumFunction(child, this)
-}
-
-case class SumDistinct(child: Expression)
- extends PartialAggregate with trees.UnaryNode[Expression] {
-
- def this() = this(null)
- override def nullable: Boolean = true
- override def dataType: DataType = child.dataType match {
- case DecimalType.Fixed(precision, scale) =>
- DecimalType(precision + 10, scale) // Add 10 digits left of decimal point, like Hive
- case DecimalType.Unlimited =>
- DecimalType.Unlimited
- case _ =>
- child.dataType
- }
- override def toString: String = s"SUM(DISTINCT $child)"
- override def newInstance(): SumDistinctFunction = new SumDistinctFunction(child, this)
-
- override def asPartial: SplitEvaluation = {
- val partialSet = Alias(CollectHashSet(child :: Nil), "partialSets")()
- SplitEvaluation(
- CombineSetsAndSum(partialSet.toAttribute, this),
- partialSet :: Nil)
- }
-}
-case class CombineSetsAndSum(inputSet: Expression, base: Expression) extends AggregateExpression {
- def this() = this(null, null)
-
- override def children: Seq[Expression] = inputSet :: Nil
- override def nullable: Boolean = true
- override def dataType: DataType = base.dataType
- override def toString: String = s"CombineAndSum($inputSet)"
- override def newInstance(): CombineSetsAndSumFunction = {
- new CombineSetsAndSumFunction(inputSet, this)
- }
-}
-
-case class CombineSetsAndSumFunction(
- @transient inputSet: Expression,
- @transient base: AggregateExpression)
- extends AggregateFunction {
-
- def this() = this(null, null) // Required for serialization.
-
- val seen = new OpenHashSet[Any]()
-
- override def update(input: InternalRow): Unit = {
- val inputSetEval = inputSet.eval(input).asInstanceOf[OpenHashSet[Any]]
- val inputIterator = inputSetEval.iterator
- while (inputIterator.hasNext) {
- seen.add(inputIterator.next)
- }
- }
-
- override def eval(input: InternalRow): Any = {
- val casted = seen.asInstanceOf[OpenHashSet[InternalRow]]
- if (casted.size == 0) {
- null
- } else {
- Cast(Literal(
- casted.iterator.map(f => f.apply(0)).reduceLeft(
- base.dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)),
- base.dataType).eval(null)
- }
- }
-}
-
-case class First(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
- override def nullable: Boolean = true
- override def dataType: DataType = child.dataType
- override def toString: String = s"FIRST($child)"
-
- override def asPartial: SplitEvaluation = {
- val partialFirst = Alias(First(child), "PartialFirst")()
- SplitEvaluation(
- First(partialFirst.toAttribute),
- partialFirst :: Nil)
- }
- override def newInstance(): FirstFunction = new FirstFunction(child, this)
-}
-
-case class Last(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
- override def references: AttributeSet = child.references
- override def nullable: Boolean = true
- override def dataType: DataType = child.dataType
- override def toString: String = s"LAST($child)"
-
- override def asPartial: SplitEvaluation = {
- val partialLast = Alias(Last(child), "PartialLast")()
- SplitEvaluation(
- Last(partialLast.toAttribute),
- partialLast :: Nil)
- }
- override def newInstance(): LastFunction = new LastFunction(child, this)
+ override def checkInputDataTypes(): TypeCheckResult =
+ TypeUtils.checkForNumericExpr(child.dataType, "function average")
}
case class AverageFunction(expr: Expression, base: AggregateExpression)
@@ -551,55 +481,41 @@ case class AverageFunction(expr: Expression, base: AggregateExpression)
}
}
-case class CountFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
- def this() = this(null, null) // Required for serialization.
+case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
- var count: Long = _
+ override def nullable: Boolean = true
- override def update(input: InternalRow): Unit = {
- val evaluatedExpr = expr.eval(input)
- if (evaluatedExpr != null) {
- count += 1L
- }
+ override def dataType: DataType = child.dataType match {
+ case DecimalType.Fixed(precision, scale) =>
+ DecimalType(precision + 10, scale) // Add 10 digits left of decimal point, like Hive
+ case DecimalType.Unlimited =>
+ DecimalType.Unlimited
+ case _ =>
+ child.dataType
}
- override def eval(input: InternalRow): Any = count
-}
-
-case class ApproxCountDistinctPartitionFunction(
- expr: Expression,
- base: AggregateExpression,
- relativeSD: Double)
- extends AggregateFunction {
- def this() = this(null, null, 0) // Required for serialization.
+ override def toString: String = s"SUM($child)"
- private val hyperLogLog = new HyperLogLog(relativeSD)
+ override def asPartial: SplitEvaluation = {
+ child.dataType match {
+ case DecimalType.Fixed(_, _) =>
+ val partialSum = Alias(Sum(Cast(child, DecimalType.Unlimited)), "PartialSum")()
+ SplitEvaluation(
+ Cast(CombineSum(partialSum.toAttribute), dataType),
+ partialSum :: Nil)
- override def update(input: InternalRow): Unit = {
- val evaluatedExpr = expr.eval(input)
- if (evaluatedExpr != null) {
- hyperLogLog.offer(evaluatedExpr)
+ case _ =>
+ val partialSum = Alias(Sum(child), "PartialSum")()
+ SplitEvaluation(
+ CombineSum(partialSum.toAttribute),
+ partialSum :: Nil)
}
}
- override def eval(input: InternalRow): Any = hyperLogLog
-}
-
-case class ApproxCountDistinctMergeFunction(
- expr: Expression,
- base: AggregateExpression,
- relativeSD: Double)
- extends AggregateFunction {
- def this() = this(null, null, 0) // Required for serialization.
-
- private val hyperLogLog = new HyperLogLog(relativeSD)
-
- override def update(input: InternalRow): Unit = {
- val evaluatedExpr = expr.eval(input)
- hyperLogLog.addAll(evaluatedExpr.asInstanceOf[HyperLogLog])
- }
+ override def newInstance(): SumFunction = new SumFunction(child, this)
- override def eval(input: InternalRow): Any = hyperLogLog.cardinality()
+ override def checkInputDataTypes(): TypeCheckResult =
+ TypeUtils.checkForNumericExpr(child.dataType, "function sum")
}
case class SumFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
@@ -633,6 +549,30 @@ case class SumFunction(expr: Expression, base: AggregateExpression) extends Aggr
}
}
+/**
+ * Sum should satisfy 3 cases:
+ * 1) sum of all null values = zero
+ * 2) sum for table column with no data = null
+ * 3) sum of column with null and not null values = sum of not null values
+ * Require separate CombineSum Expression and function as it has to distinguish "No data" case
+ * versus "data equals null" case, while aggregating results and at each partial expression, i.e.,
+ * Combining PartitionLevel InputData
+ * <-- null
+ * Zero <-- Zero <-- null
+ *
+ * <-- null <-- no data
+ * null <-- null <-- no data
+ */
+case class CombineSum(child: Expression) extends AggregateExpression {
+ def this() = this(null)
+
+ override def children: Seq[Expression] = child :: Nil
+ override def nullable: Boolean = true
+ override def dataType: DataType = child.dataType
+ override def toString: String = s"CombineSum($child)"
+ override def newInstance(): CombineSumFunction = new CombineSumFunction(child, this)
+}
+
case class CombineSumFunction(expr: Expression, base: AggregateExpression)
extends AggregateFunction {
@@ -670,6 +610,33 @@ case class CombineSumFunction(expr: Expression, base: AggregateExpression)
}
}
+case class SumDistinct(child: Expression)
+ extends PartialAggregate with trees.UnaryNode[Expression] {
+
+ def this() = this(null)
+ override def nullable: Boolean = true
+ override def dataType: DataType = child.dataType match {
+ case DecimalType.Fixed(precision, scale) =>
+ DecimalType(precision + 10, scale) // Add 10 digits left of decimal point, like Hive
+ case DecimalType.Unlimited =>
+ DecimalType.Unlimited
+ case _ =>
+ child.dataType
+ }
+ override def toString: String = s"SUM(DISTINCT $child)"
+ override def newInstance(): SumDistinctFunction = new SumDistinctFunction(child, this)
+
+ override def asPartial: SplitEvaluation = {
+ val partialSet = Alias(CollectHashSet(child :: Nil), "partialSets")()
+ SplitEvaluation(
+ CombineSetsAndSum(partialSet.toAttribute, this),
+ partialSet :: Nil)
+ }
+
+ override def checkInputDataTypes(): TypeCheckResult =
+ TypeUtils.checkForNumericExpr(child.dataType, "function sumDistinct")
+}
+
case class SumDistinctFunction(expr: Expression, base: AggregateExpression)
extends AggregateFunction {
@@ -696,8 +663,20 @@ case class SumDistinctFunction(expr: Expression, base: AggregateExpression)
}
}
-case class CountDistinctFunction(
- @transient expr: Seq[Expression],
+case class CombineSetsAndSum(inputSet: Expression, base: Expression) extends AggregateExpression {
+ def this() = this(null, null)
+
+ override def children: Seq[Expression] = inputSet :: Nil
+ override def nullable: Boolean = true
+ override def dataType: DataType = base.dataType
+ override def toString: String = s"CombineAndSum($inputSet)"
+ override def newInstance(): CombineSetsAndSumFunction = {
+ new CombineSetsAndSumFunction(inputSet, this)
+ }
+}
+
+case class CombineSetsAndSumFunction(
+ @transient inputSet: Expression,
@transient base: AggregateExpression)
extends AggregateFunction {
@@ -705,17 +684,39 @@ case class CountDistinctFunction(
val seen = new OpenHashSet[Any]()
- @transient
- val distinctValue = new InterpretedProjection(expr)
-
override def update(input: InternalRow): Unit = {
- val evaluatedExpr = distinctValue(input)
- if (!evaluatedExpr.anyNull) {
- seen.add(evaluatedExpr)
+ val inputSetEval = inputSet.eval(input).asInstanceOf[OpenHashSet[Any]]
+ val inputIterator = inputSetEval.iterator
+ while (inputIterator.hasNext) {
+ seen.add(inputIterator.next)
}
}
- override def eval(input: InternalRow): Any = seen.size.toLong
+ override def eval(input: InternalRow): Any = {
+ val casted = seen.asInstanceOf[OpenHashSet[InternalRow]]
+ if (casted.size == 0) {
+ null
+ } else {
+ Cast(Literal(
+ casted.iterator.map(f => f.apply(0)).reduceLeft(
+ base.dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)),
+ base.dataType).eval(null)
+ }
+ }
+}
+
+case class First(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
+ override def nullable: Boolean = true
+ override def dataType: DataType = child.dataType
+ override def toString: String = s"FIRST($child)"
+
+ override def asPartial: SplitEvaluation = {
+ val partialFirst = Alias(First(child), "PartialFirst")()
+ SplitEvaluation(
+ First(partialFirst.toAttribute),
+ partialFirst :: Nil)
+ }
+ override def newInstance(): FirstFunction = new FirstFunction(child, this)
}
case class FirstFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
@@ -732,6 +733,21 @@ case class FirstFunction(expr: Expression, base: AggregateExpression) extends Ag
override def eval(input: InternalRow): Any = result
}
+case class Last(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
+ override def references: AttributeSet = child.references
+ override def nullable: Boolean = true
+ override def dataType: DataType = child.dataType
+ override def toString: String = s"LAST($child)"
+
+ override def asPartial: SplitEvaluation = {
+ val partialLast = Alias(Last(child), "PartialLast")()
+ SplitEvaluation(
+ Last(partialLast.toAttribute),
+ partialLast :: Nil)
+ }
+ override def newInstance(): LastFunction = new LastFunction(child, this)
+}
+
case class LastFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
def this() = this(null, null) // Required for serialization.
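The aggregate checks delegate to TypeUtils.checkForNumericExpr and TypeUtils.checkForOrderingExpr, which predate this diff and are not shown. A plausible sketch over the toy model (an assumption, reconstructed from the error messages asserted in ExpressionTypeCheckingSuite below):

    // Numeric aggregates (sum, sumDistinct, average) require numeric input.
    def checkForNumericExpr(t: DT, caller: String): TypeCheckResult = t match {
      case IntT | DoubleT => TypeCheckResult.TypeCheckSuccess
      case other => TypeCheckResult.TypeCheckFailure(
        s"$caller accepts numeric type, not $other")
    }

    // checkForOrderingExpr (used by min/max) is analogous: it accepts any
    // orderable type and fails with "$caller accepts non-complex type".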
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index ace8427c8d..3d4d9e2d79 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -25,8 +25,6 @@ import org.apache.spark.sql.types._
abstract class UnaryArithmetic extends UnaryExpression {
self: Product =>
- override def foldable: Boolean = child.foldable
- override def nullable: Boolean = child.nullable
override def dataType: DataType = child.dataType
override def eval(input: InternalRow): Any = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index e0bf07ed18..5def57b067 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -17,9 +17,10 @@
package org.apache.spark.sql.catalyst.expressions
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
-
/**
* Returns an Array containing the evaluation of all children expressions.
*/
@@ -27,15 +28,12 @@ case class CreateArray(children: Seq[Expression]) extends Expression {
override def foldable: Boolean = children.forall(_.foldable)
- lazy val childTypes = children.map(_.dataType).distinct
-
- override lazy val resolved =
- childrenResolved && childTypes.size <= 1
+ override def checkInputDataTypes(): TypeCheckResult =
+ TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), "function array")
override def dataType: DataType = {
- assert(resolved, s"Invalid dataType of mixed ArrayType ${childTypes.mkString(",")}")
ArrayType(
- childTypes.headOption.getOrElse(NullType),
+ children.headOption.map(_.dataType).getOrElse(NullType),
containsNull = children.exists(_.nullable))
}
@@ -56,19 +54,15 @@ case class CreateStruct(children: Seq[Expression]) extends Expression {
override def foldable: Boolean = children.forall(_.foldable)
- override lazy val resolved: Boolean = childrenResolved
-
override lazy val dataType: StructType = {
- assert(resolved,
- s"CreateStruct contains unresolvable children: ${children.filterNot(_.resolved)}.")
- val fields = children.zipWithIndex.map { case (child, idx) =>
- child match {
- case ne: NamedExpression =>
- StructField(ne.name, ne.dataType, ne.nullable, ne.metadata)
- case _ =>
- StructField(s"col${idx + 1}", child.dataType, child.nullable, Metadata.empty)
- }
+ val fields = children.zipWithIndex.map { case (child, idx) =>
+ child match {
+ case ne: NamedExpression =>
+ StructField(ne.name, ne.dataType, ne.nullable, ne.metadata)
+ case _ =>
+ StructField(s"col${idx + 1}", child.dataType, child.nullable, Metadata.empty)
}
+ }
StructType(fields)
}
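Because mixed-type arrays are now rejected during analysis, CreateArray.dataType can read the first child's type directly instead of asserting on resolution. The empty case falls back to NullType; in sketch form:

    // Element type of an array literal: the first child's type, or a null
    // type when there are no children (nullType is a hypothetical stand-in).
    def arrayElementType(childTypes: Seq[DT], nullType: DT): DT =
      childTypes.headOption.getOrElse(nullType)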
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala
index 2bc893af02..f5c2dde191 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala
@@ -17,16 +17,17 @@
package org.apache.spark.sql.catalyst.expressions
-import org.apache.spark.sql.catalyst
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
import org.apache.spark.sql.types._
-/** Return the unscaled Long value of a Decimal, assuming it fits in a Long */
+/**
+ * Return the unscaled Long value of a Decimal, assuming it fits in a Long.
+ * Note: this expression is internal and created only by the optimizer,
+ * we don't need to do type check for it.
+ */
case class UnscaledValue(child: Expression) extends UnaryExpression {
override def dataType: DataType = LongType
- override def foldable: Boolean = child.foldable
- override def nullable: Boolean = child.nullable
override def toString: String = s"UnscaledValue($child)"
override def eval(input: InternalRow): Any = {
@@ -43,12 +44,14 @@ case class UnscaledValue(child: Expression) extends UnaryExpression {
}
}
-/** Create a Decimal from an unscaled Long value */
+/**
+ * Create a Decimal from an unscaled Long value.
+ * Note: this expression is internal and created only by the optimizer,
+ * we don't need to do type check for it.
+ */
case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends UnaryExpression {
override def dataType: DataType = DecimalType(precision, scale)
- override def foldable: Boolean = child.foldable
- override def nullable: Boolean = child.nullable
override def toString: String = s"MakeDecimal($child,$precision,$scale)"
override def eval(input: InternalRow): Decimal = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
index f30cb42d12..356560e54c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import scala.collection.Map
-import org.apache.spark.sql.catalyst
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, trees}
import org.apache.spark.sql.types._
@@ -100,9 +100,14 @@ case class UserDefinedGenerator(
case class Explode(child: Expression)
extends Generator with trees.UnaryNode[Expression] {
- override lazy val resolved =
- child.resolved &&
- (child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType])
+ override def checkInputDataTypes(): TypeCheckResult = {
+ if (child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType]) {
+ TypeCheckResult.TypeCheckSuccess
+ } else {
+ TypeCheckResult.TypeCheckFailure(
+ s"input to function explode should be array or map type, not ${child.dataType}")
+ }
+ }
override def elementTypes: Seq[(DataType, Boolean)] = child.dataType match {
case ArrayType(et, containsNull) => (et, containsNull) :: Nil
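Explode's check, in isolation over the toy model extended with an array type (maps omitted for brevity):

    case class ArrayT(element: DT) extends DT

    def checkExplode(childType: DT): TypeCheckResult = childType match {
      case ArrayT(_) => TypeCheckResult.TypeCheckSuccess  // arrays (and maps) explode
      case other => TypeCheckResult.TypeCheckFailure(
        s"input to function explode should be array or map type, not $other")
    }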
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
index 250564dc4b..5694afc61b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.expressions
import java.lang.{Long => JLong}
-import org.apache.spark.sql.catalyst
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StringType}
import org.apache.spark.unsafe.types.UTF8String
@@ -60,7 +59,6 @@ abstract class UnaryMathExpression(f: Double => Double, name: String)
override def expectedChildTypes: Seq[DataType] = Seq(DoubleType)
override def dataType: DataType = DoubleType
- override def foldable: Boolean = child.foldable
override def nullable: Boolean = true
override def toString: String = s"$name($child)"
@@ -224,7 +222,7 @@ case class Bin(child: Expression)
def funcName: String = name.toLowerCase
- override def eval(input: catalyst.InternalRow): Any = {
+ override def eval(input: InternalRow): Any = {
val evalE = child.eval(input)
if (evalE == null) {
null
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index 9cacdceb13..6f56a9ec7b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -17,7 +17,6 @@
package org.apache.spark.sql.catalyst.expressions
-import org.apache.spark.sql.catalyst
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
@@ -113,7 +112,8 @@ case class Alias(child: Expression, name: String)(
extends NamedExpression with trees.UnaryNode[Expression] {
// Alias(Generator, xx) need to be transformed into Generate(generator, ...)
- override lazy val resolved = childrenResolved && !child.isInstanceOf[Generator]
+ override lazy val resolved =
+ childrenResolved && checkInputDataTypes().isSuccess && !child.isInstanceOf[Generator]
override def eval(input: InternalRow): Any = child.eval(input)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala
index 98acaf23c4..5d5911403e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala
@@ -17,33 +17,32 @@
package org.apache.spark.sql.catalyst.expressions
-import org.apache.spark.sql.catalyst.analysis.UnresolvedException
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
+import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types.DataType
case class Coalesce(children: Seq[Expression]) extends Expression {
/** Coalesce is nullable if all of its children are nullable, or if it has no children. */
- override def nullable: Boolean = !children.exists(!_.nullable)
+ override def nullable: Boolean = children.forall(_.nullable)
// Coalesce is foldable if all children are foldable.
- override def foldable: Boolean = !children.exists(!_.foldable)
+ override def foldable: Boolean = children.forall(_.foldable)
- // Only resolved if all the children are of the same type.
- override lazy val resolved = childrenResolved && (children.map(_.dataType).distinct.size == 1)
+ override def checkInputDataTypes(): TypeCheckResult = {
+ if (children == Nil) {
+ TypeCheckResult.TypeCheckFailure("input to function coalesce cannot be empty")
+ } else {
+ TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), "function coalesce")
+ }
+ }
override def toString: String = s"Coalesce(${children.mkString(",")})"
- override def dataType: DataType = if (resolved) {
- children.head.dataType
- } else {
- val childTypes = children.map(c => s"$c: ${c.dataType}").mkString(", ")
- throw new UnresolvedException(
- this, s"Coalesce cannot have children of different types. $childTypes")
- }
+ override def dataType: DataType = children.head.dataType
override def eval(input: InternalRow): Any = {
- var i = 0
var result: Any = null
val childIterator = children.iterator
while (childIterator.hasNext && result == null) {
@@ -75,7 +74,6 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
}
case class IsNull(child: Expression) extends UnaryExpression with Predicate {
- override def foldable: Boolean = child.foldable
override def nullable: Boolean = false
override def eval(input: InternalRow): Any = {
@@ -93,7 +91,6 @@ case class IsNull(child: Expression) extends UnaryExpression with Predicate {
}
case class IsNotNull(child: Expression) extends UnaryExpression with Predicate {
- override def foldable: Boolean = child.foldable
override def nullable: Boolean = false
override def toString: String = s"IS NOT NULL $child"
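The nullable/foldable rewrites above rely on the identity !xs.exists(x => !p(x)) == xs.forall(p); behavior is unchanged, including the vacuous-truth edge case the doc comment calls out:

    val none = Seq.empty[Boolean]
    assert(none.forall(identity))  // vacuously true: Coalesce with no children is nullable

    val mixed = Seq(true, false)
    assert(!mixed.exists(b => !b) == mixed.forall(identity))  // both false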
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala
index 30e41677b7..efc6f50b78 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala
@@ -78,6 +78,8 @@ case class NewSet(elementType: DataType) extends LeafExpression {
/**
* Adds an item to a set.
* For performance, this expression mutates its input during evaluation.
+ * Note: this expression is internal and created only by the GeneratedAggregate,
+ * we don't need to do type check for it.
*/
case class AddItemToSet(item: Expression, set: Expression) extends Expression {
@@ -85,7 +87,7 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression {
override def nullable: Boolean = set.nullable
- override def dataType: OpenHashSetUDT = set.dataType.asInstanceOf[OpenHashSetUDT]
+ override def dataType: DataType = set.dataType
override def eval(input: InternalRow): Any = {
val itemEval = item.eval(input)
@@ -128,12 +130,14 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression {
/**
* Combines the elements of two sets.
* For performance, this expression mutates its left input set during evaluation.
+ * Note: this expression is internal and created only by the GeneratedAggregate,
+ * we don't need to do type check for it.
*/
case class CombineSets(left: Expression, right: Expression) extends BinaryExpression {
override def nullable: Boolean = left.nullable || right.nullable
- override def dataType: OpenHashSetUDT = left.dataType.asInstanceOf[OpenHashSetUDT]
+ override def dataType: DataType = left.dataType
override def symbol: String = "++="
@@ -176,6 +180,8 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres
/**
* Returns the number of elements in the input set.
+ * Note: this expression is internal and created only by the GeneratedAggregate,
+ * we don't need to do type check for it.
*/
case class CountSet(child: Expression) extends UnaryExpression {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 315c63e63c..44416e79cd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -117,8 +117,6 @@ trait CaseConversionExpression extends ExpectsInputTypes {
def convert(v: UTF8String): UTF8String
- override def foldable: Boolean = child.foldable
- override def nullable: Boolean = child.nullable
override def dataType: DataType = StringType
override def expectedChildTypes: Seq[DataType] = Seq(StringType)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
index 896e383f50..12023ad311 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
@@ -68,7 +68,8 @@ case class WindowSpecDefinition(
override def children: Seq[Expression] = partitionSpec ++ orderSpec
override lazy val resolved: Boolean =
- childrenResolved && frameSpecification.isInstanceOf[SpecifiedWindowFrame]
+ childrenResolved && checkInputDataTypes().isSuccess &&
+ frameSpecification.isInstanceOf[SpecifiedWindowFrame]
override def toString: String = simpleString
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
index 04857a23f4..8656cc334d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
@@ -48,6 +48,15 @@ object TypeUtils {
}
}
+ def checkForSameTypeInputExpr(types: Seq[DataType], caller: String): TypeCheckResult = {
+ if (types.distinct.size > 1) {
+ TypeCheckResult.TypeCheckFailure(
+ s"input to $caller should all be the same type, but it's ${types.mkString("[", ", ", "]")}")
+ } else {
+ TypeCheckResult.TypeCheckSuccess
+ }
+ }
+
def getNumeric(t: DataType): Numeric[Any] =
t.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]]
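Over the toy model, the new helper behaves exactly as the hunk reads:

    def checkForSameTypeInputExpr(types: Seq[DT], caller: String): TypeCheckResult =
      if (types.distinct.size > 1) {
        TypeCheckResult.TypeCheckFailure(
          s"input to $caller should all be the same type, but it's " +
            types.mkString("[", ", ", "]"))
      } else {
        TypeCheckResult.TypeCheckSuccess
      }

    // checkForSameTypeInputExpr(Seq(IntT, IntT), "function array")       => success
    // checkForSameTypeInputExpr(Seq(IntT, StringT), "function coalesce") =>
    //   failure: input to function coalesce should all be the same type,
    //   but it's [IntT, StringT]

Note that an empty input passes this helper (distinct size 0), which is why the Coalesce hunk above adds its own "cannot be empty" failure.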
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index e09cd790a7..77ca080f36 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -193,7 +193,7 @@ class AnalysisSuite extends SparkFunSuite with BeforeAndAfter {
errorTest(
"bad casts",
testRelation.select(Literal(1).cast(BinaryType).as('badCast)),
- "invalid cast" :: Literal(1).dataType.simpleString :: BinaryType.simpleString :: Nil)
+ "cannot cast" :: Literal(1).dataType.simpleString :: BinaryType.simpleString :: Nil)
errorTest(
"non-boolean filters",
@@ -264,9 +264,9 @@ class AnalysisSuite extends SparkFunSuite with BeforeAndAfter {
val plan =
Aggregate(
Nil,
- Alias(Sum(AttributeReference("a", StringType)(exprId = ExprId(1))), "b")() :: Nil,
+ Alias(Sum(AttributeReference("a", IntegerType)(exprId = ExprId(1))), "b")() :: Nil,
LocalRelation(
- AttributeReference("a", StringType)(exprId = ExprId(2))))
+ AttributeReference("a", IntegerType)(exprId = ExprId(2))))
assert(plan.resolved)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
index 49b1119897..bc1537b071 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
@@ -15,13 +15,13 @@
* limitations under the License.
*/
-package org.apache.spark.sql.catalyst.expressions
+package org.apache.spark.sql.catalyst.analysis
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.types.StringType
@@ -136,6 +136,28 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
assertError(
CaseWhen(Seq('booleanField, 'intField, 'intField, 'intField)),
"WHEN expressions in CaseWhen should all be boolean type")
+ }
+
+ test("check types for aggregates") {
+ // We will cast String to Double for sum and average
+ assertSuccess(Sum('stringField))
+ assertSuccess(SumDistinct('stringField))
+ assertSuccess(Average('stringField))
+
+ assertError(Min('complexField), "function min accepts non-complex type")
+ assertError(Max('complexField), "function max accepts non-complex type")
+ assertError(Sum('booleanField), "function sum accepts numeric type")
+ assertError(SumDistinct('booleanField), "function sumDistinct accepts numeric type")
+ assertError(Average('booleanField), "function average accepts numeric type")
+ }
+
+ test("check types for others") {
+ assertError(CreateArray(Seq('intField, 'booleanField)),
+ "input to function array should all be the same type")
+ assertError(Coalesce(Seq('intField, 'booleanField)),
+ "input to function coalesce should all be the same type")
+ assertError(Coalesce(Nil), "input to function coalesce cannot be empty")
+ assertError(Explode('intField),
+ "input to function explode should be array or map type")
}
}
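assertError and assertSuccess are pre-existing suite helpers not shown in this hunk; presumably they analyze the expression and intercept the resulting AnalysisException. End to end, the flow these tests exercise reduces to (toy model):

    // Analysis surfaces a type-check failure as an exception whose message
    // contains the text the tests assert on.
    def analyzeOrThrow(check: TypeCheckResult): Unit = check match {
      case TypeCheckResult.TypeCheckFailure(msg) => throw new RuntimeException(msg)
      case _ => ()
    }

    analyzeOrThrow(checkForSameTypeInputExpr(Seq(IntT, IntT), "function array"))  // ok
    // analyzeOrThrow(checkForSameTypeInputExpr(Seq(IntT, StringT), "function array"))
    //   would throw with the asserted message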
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
index 6db551c543..f9c3fe92c2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
@@ -55,7 +55,7 @@ private[spark] case class PythonUDF(
override def toString: String = s"PythonUDF#$name(${children.mkString(",")})"
- def nullable: Boolean = true
+ override def nullable: Boolean = true
override def eval(input: InternalRow): Any = {
throw new UnsupportedOperationException("PythonUDFs can not be directly evaluated.")
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala
index f0f04f8c73..197e9bfb02 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala
@@ -59,10 +59,4 @@ class HiveTypeCoercionSuite extends HiveComparisonTest {
}
assert(numEquals === 1)
}
-
- test("COALESCE with different types") {
- intercept[RuntimeException] {
- TestHive.sql("""SELECT COALESCE(1, true, "abc") FROM src limit 1""").collect()
- }
- }
}