author     Herman van Hovell <hvanhovell@databricks.com>    2016-11-29 20:05:15 -0800
committer  Reynold Xin <rxin@databricks.com>                2016-11-29 20:05:15 -0800
commit     af9789a4f5d00b3141f102e9f0ca52217e26c082 (patch)
tree       99b5e53726166451ccd1a65ced9e010d943e7d2a /sql/catalyst/src
parent     9b670bcaec9c220603ec10a6d186865dabf26a5b (diff)
[SPARK-18632][SQL] AggregateFunction should not implement ImplicitCastInputTypes
## What changes were proposed in this pull request?

`AggregateFunction` currently implements `ImplicitCastInputTypes` (which enables implicit input type casting). There are quite a few situations in which we do not need this, or need more control over the input. A recent example is the aggregate for `CountMinSketch`, which should only take string, binary or integral inputs. This PR removes `ImplicitCastInputTypes` from `AggregateFunction` and makes a case-by-case decision on what kind of input validation each aggregate should use.

## How was this patch tested?

Refactoring only. Existing tests.

Author: Herman van Hovell <hvanhovell@databricks.com>

Closes #16066 from hvanhovell/SPARK-18632.
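For context, a minimal standalone sketch of the contract this commit changes. It is not Spark code: the names are borrowed from Catalyst, but the model is simplified and hypothetical. `ExpectsInputTypes` only validates input types, `ImplicitCastInputTypes` additionally marks that the analyzer may insert casts, and after this commit the base `AggregateFunction` prescribes neither, so each aggregate mixes in exactly what it needs.

```scala
object InputTypeTraitsSketch {
  sealed trait DataType
  case object DoubleType extends DataType
  case object BooleanType extends DataType
  case object AnyDataType extends DataType

  trait Expression { def dataType: DataType }
  final case class Column(name: String, dataType: DataType) extends Expression

  // Validation only: mismatched inputs fail analysis instead of being cast.
  trait ExpectsInputTypes { self: AggregateFunction =>
    def inputTypes: Seq[DataType]
    def checkInputDataTypes(): Either[String, Unit] = {
      val actual = children.map(_.dataType)
      val ok = actual.zip(inputTypes).forall {
        case (_, AnyDataType) => true
        case (a, e)           => a == e
      }
      if (ok) Right(()) else Left(s"expected $inputTypes, got $actual")
    }
  }

  // Validation plus permission for the analyzer to insert implicit casts.
  trait ImplicitCastInputTypes extends ExpectsInputTypes { self: AggregateFunction => }

  // After this commit the base class forces neither behavior.
  abstract class AggregateFunction extends Expression {
    def children: Seq[Expression]
  }

  // Sum opts back in to implicit casting (as it does in the diff below).
  final case class Sum(child: Expression)
      extends AggregateFunction with ImplicitCastInputTypes {
    def children: Seq[Expression] = Seq(child)
    def dataType: DataType = DoubleType
    def inputTypes: Seq[DataType] = Seq(DoubleType)
  }

  def main(args: Array[String]): Unit = {
    // Left(expected List(DoubleType), got List(BooleanType))
    println(Sum(Column("flag", BooleanType)).checkInputDataTypes())
  }
}
```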
Diffstat (limited to 'sql/catalyst/src')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala | 7
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala | 3
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala | 3
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala | 3
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala | 5
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala | 3
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala | 26
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala | 26
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala | 3
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala | 3
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala | 9
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala | 2
18 files changed, 57 insertions, 48 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala
index 692cbd7c0d..c2cd8951bc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala
@@ -22,11 +22,11 @@ import java.nio.ByteBuffer
import com.google.common.primitives.{Doubles, Ints, Longs}
import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.{InternalRow}
+import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.{PercentileDigest}
+import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.PercentileDigest
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
import org.apache.spark.sql.catalyst.util.QuantileSummaries
import org.apache.spark.sql.catalyst.util.QuantileSummaries.{defaultCompressThreshold, Stats}
@@ -71,7 +71,8 @@ case class ApproximatePercentile(
percentageExpression: Expression,
accuracyExpression: Expression,
override val mutableAggBufferOffset: Int,
- override val inputAggBufferOffset: Int) extends TypedImperativeAggregate[PercentileDigest] {
+ override val inputAggBufferOffset: Int)
+ extends TypedImperativeAggregate[PercentileDigest] with ImplicitCastInputTypes {
def this(child: Expression, percentageExpression: Expression, accuracyExpression: Expression) = {
this(child, percentageExpression, accuracyExpression, 0, 0)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
index d523420530..c423e17169 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.types._
@ExpressionDescription(
usage = "_FUNC_(expr) - Returns the mean calculated from values of a group.")
-case class Average(child: Expression) extends DeclarativeAggregate {
+case class Average(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes {
override def prettyName: String = "avg"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
index 1a93f45903..572d29caf5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
@@ -42,7 +42,8 @@ import org.apache.spark.sql.types._
*
* @param child to compute central moments of.
*/
-abstract class CentralMomentAgg(child: Expression) extends DeclarativeAggregate {
+abstract class CentralMomentAgg(child: Expression)
+ extends DeclarativeAggregate with ImplicitCastInputTypes {
/**
* The central moment order to be computed.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala
index 657f519d2a..95a4a0d5af 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala
@@ -32,7 +32,8 @@ import org.apache.spark.sql.types._
@ExpressionDescription(
usage = "_FUNC_(expr1, expr2) - Returns Pearson coefficient of correlation between a set of number pairs.")
// scalastyle:on line.size.limit
-case class Corr(x: Expression, y: Expression) extends DeclarativeAggregate {
+case class Corr(x: Expression, y: Expression)
+ extends DeclarativeAggregate with ImplicitCastInputTypes {
override def children: Seq[Expression] = Seq(x, y)
override def nullable: Boolean = true
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
index bcae0dc075..1990f2f2f0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
@@ -38,9 +38,6 @@ case class Count(children: Seq[Expression]) extends DeclarativeAggregate {
// Return data type.
override def dataType: DataType = LongType
- // Expected input data type.
- override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(AnyDataType)
-
private lazy val count = AttributeReference("count", LongType, nullable = false)()
override lazy val aggBufferAttributes = count :: Nil
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala
index 1bfae9e5a4..f5f185f2c5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala
@@ -22,7 +22,7 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
-import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription}
+import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExpressionDescription}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.sketch.CountMinSketch
@@ -52,7 +52,8 @@ case class CountMinSketchAgg(
confidenceExpression: Expression,
seedExpression: Expression,
override val mutableAggBufferOffset: Int,
- override val inputAggBufferOffset: Int) extends TypedImperativeAggregate[CountMinSketch] {
+ override val inputAggBufferOffset: Int)
+ extends TypedImperativeAggregate[CountMinSketch] with ExpectsInputTypes {
def this(
child: Expression,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala
index ae5ed77970..fc6c34baaf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala
@@ -25,7 +25,8 @@ import org.apache.spark.sql.types._
* Compute the covariance between two expressions.
* When applied on empty data (i.e., count is zero), it returns NULL.
*/
-abstract class Covariance(x: Expression, y: Expression) extends DeclarativeAggregate {
+abstract class Covariance(x: Expression, y: Expression)
+ extends DeclarativeAggregate with ImplicitCastInputTypes {
override def children: Seq[Expression] = Seq(x, y)
override def nullable: Boolean = true
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala
index 29b8947980..bfc58c2288 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala
@@ -17,7 +17,8 @@
package org.apache.spark.sql.catalyst.expressions.aggregate
-import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
@@ -33,16 +34,11 @@ import org.apache.spark.sql.types._
_FUNC_(expr[, isIgnoreNull]) - Returns the first value of `expr` for a group of rows.
If `isIgnoreNull` is true, returns only non-null values.
""")
-case class First(child: Expression, ignoreNullsExpr: Expression) extends DeclarativeAggregate {
+case class First(child: Expression, ignoreNullsExpr: Expression)
+ extends DeclarativeAggregate with ExpectsInputTypes {
def this(child: Expression) = this(child, Literal.create(false, BooleanType))
- private val ignoreNulls: Boolean = ignoreNullsExpr match {
- case Literal(b: Boolean, BooleanType) => b
- case _ =>
- throw new AnalysisException("The second argument of First should be a boolean literal.")
- }
-
override def children: Seq[Expression] = child :: ignoreNullsExpr :: Nil
override def nullable: Boolean = true
@@ -56,6 +52,20 @@ case class First(child: Expression, ignoreNullsExpr: Expression) extends Declara
// Expected input data type.
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, BooleanType)
+ override def checkInputDataTypes(): TypeCheckResult = {
+ val defaultCheck = super.checkInputDataTypes()
+ if (defaultCheck.isFailure) {
+ defaultCheck
+ } else if (!ignoreNullsExpr.foldable) {
+ TypeCheckFailure(
+ s"The second argument of First must be a boolean literal, but got: ${ignoreNullsExpr.sql}")
+ } else {
+ TypeCheckSuccess
+ }
+ }
+
+ private def ignoreNulls: Boolean = ignoreNullsExpr.eval().asInstanceOf[Boolean]
+
private lazy val first = AttributeReference("first", child.dataType)()
private lazy val valueSet = AttributeReference("valueSet", BooleanType)()
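The First hunk above (and the identical Last hunk below) moves validation from construction time, where a non-literal second argument threw an `AnalysisException`, to analysis time, where it surfaces as a `TypeCheckFailure`; it also accepts any foldable boolean expression rather than only a bare `Literal`, evaluating it lazily once checks pass. A standalone, hypothetical sketch of that control flow, using simplified stand-ins rather than Spark's actual classes:

```scala
// Simplified model of the new First/Last validation flow: check the
// second argument is foldable before evaluating it, instead of pattern
// matching on a Literal in the constructor and throwing eagerly.
object FoldableCheckSketch {
  sealed trait TypeCheckResult
  case object TypeCheckSuccess extends TypeCheckResult
  final case class TypeCheckFailure(msg: String) extends TypeCheckResult

  trait Expr { def foldable: Boolean; def eval(): Any; def sql: String }
  final case class BoolLiteral(b: Boolean) extends Expr {
    val foldable = true; def eval() = b; def sql = b.toString
  }
  final case class ColumnRef(name: String) extends Expr {
    val foldable = false; def eval() = sys.error("not foldable"); def sql = name
  }

  final case class First(child: Expr, ignoreNullsExpr: Expr) {
    def checkInputDataTypes(): TypeCheckResult =
      if (!ignoreNullsExpr.foldable) {
        TypeCheckFailure(
          s"The second argument of First must be a boolean literal, " +
            s"but got: ${ignoreNullsExpr.sql}")
      } else {
        TypeCheckSuccess
      }
    // Only safe to call after checkInputDataTypes() has succeeded.
    def ignoreNulls: Boolean = ignoreNullsExpr.eval().asInstanceOf[Boolean]
  }

  def main(args: Array[String]): Unit = {
    val good = First(ColumnRef("x"), BoolLiteral(true))
    println(good.checkInputDataTypes()) // TypeCheckSuccess
    println(good.ignoreNulls)           // true
    val bad = First(ColumnRef("x"), ColumnRef("flag"))
    println(bad.checkInputDataTypes())  // TypeCheckFailure(...)
  }
}
```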
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala
index 77b7eb228e..d5c9166443 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala
@@ -140,8 +140,6 @@ case class HyperLogLogPlusPlus(
override def dataType: DataType = LongType
- override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
-
override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)
/** Allocate enough words to store all registers. */
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala
index b0a363e7d6..96a6ec08a1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala
@@ -17,7 +17,8 @@
package org.apache.spark.sql.catalyst.expressions.aggregate
-import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
@@ -33,16 +34,11 @@ import org.apache.spark.sql.types._
_FUNC_(expr[, isIgnoreNull]) - Returns the last value of `expr` for a group of rows.
If `isIgnoreNull` is true, returns only non-null values.
""")
-case class Last(child: Expression, ignoreNullsExpr: Expression) extends DeclarativeAggregate {
+case class Last(child: Expression, ignoreNullsExpr: Expression)
+ extends DeclarativeAggregate with ExpectsInputTypes {
def this(child: Expression) = this(child, Literal.create(false, BooleanType))
- private val ignoreNulls: Boolean = ignoreNullsExpr match {
- case Literal(b: Boolean, BooleanType) => b
- case _ =>
- throw new AnalysisException("The second argument of First should be a boolean literal.")
- }
-
override def children: Seq[Expression] = child :: ignoreNullsExpr :: Nil
override def nullable: Boolean = true
@@ -56,6 +52,20 @@ case class Last(child: Expression, ignoreNullsExpr: Expression) extends Declarat
// Expected input data type.
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, BooleanType)
+ override def checkInputDataTypes(): TypeCheckResult = {
+ val defaultCheck = super.checkInputDataTypes()
+ if (defaultCheck.isFailure) {
+ defaultCheck
+ } else if (!ignoreNullsExpr.foldable) {
+ TypeCheckFailure(
+ s"The second argument of Last must be a boolean literal, but got: ${ignoreNullsExpr.sql}")
+ } else {
+ TypeCheckSuccess
+ }
+ }
+
+ private def ignoreNulls: Boolean = ignoreNullsExpr.eval().asInstanceOf[Boolean]
+
private lazy val last = AttributeReference("last", child.dataType)()
private lazy val valueSet = AttributeReference("valueSet", BooleanType)()
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala
index f32c9c677a..58fd1d8620 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala
@@ -33,9 +33,6 @@ case class Max(child: Expression) extends DeclarativeAggregate {
// Return data type.
override def dataType: DataType = child.dataType
- // Expected input data type.
- override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
-
override def checkInputDataTypes(): TypeCheckResult =
TypeUtils.checkForOrderingExpr(child.dataType, "function max")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala
index 9ef42b9697..b2724ee768 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala
@@ -33,9 +33,6 @@ case class Min(child: Expression) extends DeclarativeAggregate {
// Return data type.
override def dataType: DataType = child.dataType
- // Expected input data type.
- override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
-
override def checkInputDataTypes(): TypeCheckResult =
TypeUtils.checkForOrderingExpr(child.dataType, "function min")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
index 356e088d1d..b51b55313e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
@@ -54,10 +54,11 @@ import org.apache.spark.util.collection.OpenHashMap
be between 0.0 and 1.0.
""")
case class Percentile(
- child: Expression,
- percentageExpression: Expression,
- mutableAggBufferOffset: Int = 0,
- inputAggBufferOffset: Int = 0) extends TypedImperativeAggregate[OpenHashMap[Number, Long]] {
+ child: Expression,
+ percentageExpression: Expression,
+ mutableAggBufferOffset: Int = 0,
+ inputAggBufferOffset: Int = 0)
+ extends TypedImperativeAggregate[OpenHashMap[Number, Long]] with ImplicitCastInputTypes {
def this(child: Expression, percentageExpression: Expression) = {
this(child, percentageExpression, 0, 0)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala
index 0876060772..9ad31243e4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala
@@ -77,8 +77,6 @@ case class PivotFirst(
override val children: Seq[Expression] = pivotColumn :: valueColumn :: Nil
- override lazy val inputTypes: Seq[AbstractDataType] = children.map(_.dataType)
-
override val nullable: Boolean = false
val valueDataType = valueColumn.dataType
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
index f3731d4005..96e8ceec6d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.types._
@ExpressionDescription(
usage = "_FUNC_(expr) - Returns the sum calculated from values of a group.")
-case class Sum(child: Expression) extends DeclarativeAggregate {
+case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes {
override def children: Seq[Expression] = child :: Nil
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
index d2880d58ae..b176e2a128 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
@@ -44,8 +44,6 @@ abstract class Collect extends ImperativeAggregate {
override def dataType: DataType = ArrayType(child.dataType)
- override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
-
override def supportsPartial: Boolean = false
override def aggBufferAttributes: Seq[AttributeReference] = Nil
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index f3fd58bc98..7397b60360 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -155,7 +155,7 @@ case class AggregateExpression(
* Code which accepts [[AggregateFunction]] instances should be prepared to handle both types of
* aggregate functions.
*/
-sealed abstract class AggregateFunction extends Expression with ImplicitCastInputTypes {
+sealed abstract class AggregateFunction extends Expression {
/** An aggregate function is not foldable. */
final override def foldable: Boolean = false
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
index 3cbbcdf4a9..c0d6a6b92b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
@@ -443,7 +443,6 @@ abstract class AggregateWindowFunction extends DeclarativeAggregate with WindowF
abstract class RowNumberLike extends AggregateWindowFunction {
override def children: Seq[Expression] = Nil
- override def inputTypes: Seq[AbstractDataType] = Nil
protected val zero = Literal(0)
protected val one = Literal(1)
protected val rowNumber = AttributeReference("rowNumber", IntegerType, nullable = false)()
@@ -600,7 +599,6 @@ case class NTile(buckets: Expression) extends RowNumberLike with SizeBasedWindow
* This documentation has been based upon similar documentation for the Hive and Presto projects.
*/
abstract class RankLike extends AggregateWindowFunction {
- override def inputTypes: Seq[AbstractDataType] = children.map(_ => AnyDataType)
/** Store the values of the window 'order' expressions. */
protected val orderAttrs = children.map { expr =>