aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author: Reynold Xin <rxin@databricks.com>, 2015-07-18 18:18:19 -0700
committer: Reynold Xin <rxin@databricks.com>, 2015-07-18 18:18:19 -0700
commit: 9914b1b2c5d5fe020f54d95f59f03023de2ea78a (patch)
tree: 8597a82ec83149e252b9ae5ed345d013d851f22a
parent: e16a19a39ed3369dffd375d712066d12add71c9e (diff)
download: spark-9914b1b2c5d5fe020f54d95f59f03023de2ea78a.tar.gz
spark-9914b1b2c5d5fe020f54d95f59f03023de2ea78a.tar.bz2
spark-9914b1b2c5d5fe020f54d95f59f03023de2ea78a.zip
[SPARK-9150][SQL] Create CodegenFallback and Unevaluable trait
It is very hard to track which expressions have code generation implemented. This patch removes the default fallback genCode implementation from Expression, and moves it into a new trait called CodegenFallback. Each concrete expression needs to either implement code generation, or mix in CodegenFallback. This makes it very easy to track which expressions have code generation implemented already.

Additionally, this patch creates an Unevaluable trait that can be used to track expressions that don't support evaluation (e.g. Star).

Author: Reynold Xin <rxin@databricks.com>

Closes #7487 from rxin/codegenfallback and squashes the following commits:

14ebf38 [Reynold Xin] Fixed Conv
6c1c882 [Reynold Xin] Fixed Alias.
b42611b [Reynold Xin] [SPARK-9150][SQL] Create a trait to track code generation for expressions.
cb5c066 [Reynold Xin] Removed extra import.
39cbe40 [Reynold Xin] [SPARK-8240][SQL] string function: concat
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala43
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala7
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala28
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala4
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala10
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala10
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala5
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala40
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala11
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala5
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala8
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala7
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala61
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala15
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala7
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala12
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala48
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala41
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala16
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala4
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala8
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala3
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala6
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala25
24 files changed, 206 insertions, 218 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index 4a1a1ed61e..0daee1990a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -17,9 +17,8 @@
package org.apache.spark.sql.catalyst.analysis
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.{errors, trees}
-import org.apache.spark.sql.catalyst.errors.TreeNodeException
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+import org.apache.spark.sql.catalyst.errors
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.LeafNode
import org.apache.spark.sql.catalyst.trees.TreeNode
@@ -50,7 +49,7 @@ case class UnresolvedRelation(
/**
* Holds the name of an attribute that has yet to be resolved.
*/
-case class UnresolvedAttribute(nameParts: Seq[String]) extends Attribute {
+case class UnresolvedAttribute(nameParts: Seq[String]) extends Attribute with Unevaluable {
def name: String =
nameParts.map(n => if (n.contains(".")) s"`$n`" else n).mkString(".")
@@ -66,10 +65,6 @@ case class UnresolvedAttribute(nameParts: Seq[String]) extends Attribute {
override def withQualifiers(newQualifiers: Seq[String]): UnresolvedAttribute = this
override def withName(newName: String): UnresolvedAttribute = UnresolvedAttribute.quoted(newName)
- // Unresolved attributes are transient at compile time and don't get evaluated during execution.
- override def eval(input: InternalRow = null): Any =
- throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
-
override def toString: String = s"'$name"
}
@@ -78,16 +73,14 @@ object UnresolvedAttribute {
def quoted(name: String): UnresolvedAttribute = new UnresolvedAttribute(Seq(name))
}
-case class UnresolvedFunction(name: String, children: Seq[Expression]) extends Expression {
+case class UnresolvedFunction(name: String, children: Seq[Expression])
+ extends Expression with Unevaluable {
+
override def dataType: DataType = throw new UnresolvedException(this, "dataType")
override def foldable: Boolean = throw new UnresolvedException(this, "foldable")
override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
override lazy val resolved = false
- // Unresolved functions are transient at compile time and don't get evaluated during execution.
- override def eval(input: InternalRow = null): Any =
- throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
-
override def toString: String = s"'$name(${children.mkString(",")})"
}
@@ -105,10 +98,6 @@ abstract class Star extends LeafExpression with NamedExpression {
override def toAttribute: Attribute = throw new UnresolvedException(this, "toAttribute")
override lazy val resolved = false
- // Star gets expanded at runtime so we never evaluate a Star.
- override def eval(input: InternalRow = null): Any =
- throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
-
def expand(input: Seq[Attribute], resolver: Resolver): Seq[NamedExpression]
}
@@ -120,7 +109,7 @@ abstract class Star extends LeafExpression with NamedExpression {
* @param table an optional table that should be the target of the expansion. If omitted all
* tables' columns are produced.
*/
-case class UnresolvedStar(table: Option[String]) extends Star {
+case class UnresolvedStar(table: Option[String]) extends Star with Unevaluable {
override def expand(input: Seq[Attribute], resolver: Resolver): Seq[NamedExpression] = {
val expandedAttributes: Seq[Attribute] = table match {
@@ -149,7 +138,7 @@ case class UnresolvedStar(table: Option[String]) extends Star {
* @param names the names to be associated with each output of computing [[child]].
*/
case class MultiAlias(child: Expression, names: Seq[String])
- extends UnaryExpression with NamedExpression {
+ extends UnaryExpression with NamedExpression with CodegenFallback {
override def name: String = throw new UnresolvedException(this, "name")
@@ -165,9 +154,6 @@ case class MultiAlias(child: Expression, names: Seq[String])
override lazy val resolved = false
- override def eval(input: InternalRow = null): Any =
- throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
-
override def toString: String = s"$child AS $names"
}
@@ -178,7 +164,7 @@ case class MultiAlias(child: Expression, names: Seq[String])
*
* @param expressions Expressions to expand.
*/
-case class ResolvedStar(expressions: Seq[NamedExpression]) extends Star {
+case class ResolvedStar(expressions: Seq[NamedExpression]) extends Star with Unevaluable {
override def expand(input: Seq[Attribute], resolver: Resolver): Seq[NamedExpression] = expressions
override def toString: String = expressions.mkString("ResolvedStar(", ", ", ")")
}
@@ -192,23 +178,21 @@ case class ResolvedStar(expressions: Seq[NamedExpression]) extends Star {
* can be key of Map, index of Array, field name of Struct.
*/
case class UnresolvedExtractValue(child: Expression, extraction: Expression)
- extends UnaryExpression {
+ extends UnaryExpression with Unevaluable {
override def dataType: DataType = throw new UnresolvedException(this, "dataType")
override def foldable: Boolean = throw new UnresolvedException(this, "foldable")
override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
override lazy val resolved = false
- override def eval(input: InternalRow = null): Any =
- throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
-
override def toString: String = s"$child[$extraction]"
}
/**
* Holds the expression that has yet to be aliased.
*/
-case class UnresolvedAlias(child: Expression) extends UnaryExpression with NamedExpression {
+case class UnresolvedAlias(child: Expression)
+ extends UnaryExpression with NamedExpression with Unevaluable {
override def toAttribute: Attribute = throw new UnresolvedException(this, "toAttribute")
override def qualifiers: Seq[String] = throw new UnresolvedException(this, "qualifiers")
@@ -218,7 +202,4 @@ case class UnresolvedAlias(child: Expression) extends UnaryExpression with Named
override def name: String = throw new UnresolvedException(this, "name")
override lazy val resolved = false
-
- override def eval(input: InternalRow = null): Any =
- throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 692b9fddbb..3346d3c9f9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -18,12 +18,10 @@
package org.apache.spark.sql.catalyst.expressions
import java.math.{BigDecimal => JavaBigDecimal}
-import java.sql.{Date, Timestamp}
-import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
+import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{Interval, UTF8String}
@@ -106,7 +104,8 @@ object Cast {
}
/** Cast the child expression to the target data type. */
-case class Cast(child: Expression, dataType: DataType) extends UnaryExpression with Logging {
+case class Cast(child: Expression, dataType: DataType)
+ extends UnaryExpression with CodegenFallback {
override def checkInputDataTypes(): TypeCheckResult = {
if (Cast.canCast(child.dataType, dataType)) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 0e128d8bdc..d0a1aa9a1e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -101,19 +101,7 @@ abstract class Expression extends TreeNode[Expression] {
* @param ev an [[GeneratedExpressionCode]] with unique terms.
* @return Java source code
*/
- protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
- ctx.references += this
- val objectTerm = ctx.freshName("obj")
- s"""
- /* expression: ${this} */
- Object $objectTerm = expressions[${ctx.references.size - 1}].eval(i);
- boolean ${ev.isNull} = $objectTerm == null;
- ${ctx.javaType(this.dataType)} ${ev.primitive} = ${ctx.defaultValue(this.dataType)};
- if (!${ev.isNull}) {
- ${ev.primitive} = (${ctx.boxedType(this.dataType)}) $objectTerm;
- }
- """
- }
+ protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String
/**
* Returns `true` if this expression and all its children have been resolved to a specific schema
@@ -183,6 +171,20 @@ abstract class Expression extends TreeNode[Expression] {
/**
+ * An expression that cannot be evaluated. Some expressions don't live past analysis or optimization
+ * time (e.g. Star). This trait is used by those expressions.
+ *
+ * Mixing this trait in overrides both the interpreted path (`eval`) and the code generation path
+ * (`genCode`) so that each one unconditionally throws an [[UnsupportedOperationException]] whose
+ * message contains the string form of the expression. Both overrides use the same message text.
+ *
+ * The self-type (`self: Expression =>`) means the trait can only be mixed into an [[Expression]].
+ */
+trait Unevaluable { self: Expression =>
+
+ override def eval(input: InternalRow = null): Any =
+ throw new UnsupportedOperationException(s"Cannot evaluate expression: $this")
+
+ override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String =
+ throw new UnsupportedOperationException(s"Cannot evaluate expression: $this")
+}
+
+
+/**
* A leaf expression, i.e. one without any child expressions.
*/
abstract class LeafExpression extends Expression {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
index 22687acd68..11c7950c06 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.CatalystTypeConverters
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.types.DataType
/**
@@ -29,7 +30,8 @@ case class ScalaUDF(
function: AnyRef,
dataType: DataType,
children: Seq[Expression],
- inputTypes: Seq[DataType] = Nil) extends Expression with ImplicitCastInputTypes {
+ inputTypes: Seq[DataType] = Nil)
+ extends Expression with ImplicitCastInputTypes with CodegenFallback {
override def nullable: Boolean = true
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
index b8f7068c9e..3f436c0eb8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
@@ -17,9 +17,6 @@
package org.apache.spark.sql.catalyst.expressions
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.errors.TreeNodeException
-import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.types.DataType
abstract sealed class SortDirection
@@ -30,7 +27,8 @@ case object Descending extends SortDirection
* An expression that can be used to sort a tuple. This class extends expression primarily so that
* transformations over expression will descend into its child.
*/
-case class SortOrder(child: Expression, direction: SortDirection) extends UnaryExpression {
+case class SortOrder(child: Expression, direction: SortDirection)
+ extends UnaryExpression with Unevaluable {
/** Sort order is not foldable because we don't have an eval for it. */
override def foldable: Boolean = false
@@ -38,9 +36,5 @@ case class SortOrder(child: Expression, direction: SortDirection) extends UnaryE
override def dataType: DataType = child.dataType
override def nullable: Boolean = child.nullable
- // SortOrder itself is never evaluated.
- override def eval(input: InternalRow = null): Any =
- throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
-
override def toString: String = s"$child ${if (direction == Ascending) "ASC" else "DESC"}"
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index af9a674ab4..d705a12860 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -26,7 +26,8 @@ import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
import org.apache.spark.util.collection.OpenHashSet
-trait AggregateExpression extends Expression {
+
+trait AggregateExpression extends Expression with Unevaluable {
/**
* Aggregate expressions should not be foldable.
@@ -38,13 +39,6 @@ trait AggregateExpression extends Expression {
* of input rows/
*/
def newInstance(): AggregateFunction
-
- /**
- * [[AggregateExpression.eval]] should never be invoked because [[AggregateExpression]]'s are
- * replaced with a physical aggregate operator at runtime.
- */
- override def eval(input: InternalRow = null): Any =
- throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index e83650fc8c..05b5ad88fe 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
+import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.Interval
@@ -65,7 +65,8 @@ case class UnaryPositive(child: Expression) extends UnaryExpression with Expects
/**
* A function that get the absolute value of the numeric value.
*/
-case class Abs(child: Expression) extends UnaryExpression with ExpectsInputTypes {
+case class Abs(child: Expression)
+ extends UnaryExpression with ExpectsInputTypes with CodegenFallback {
override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala
new file mode 100644
index 0000000000..bf4f600cb2
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.codegen
+
+import org.apache.spark.sql.catalyst.expressions.Expression
+
+/**
+ * A trait that can be used to provide a fallback mode for expression code generation.
+ *
+ * Rather than emitting specialized Java for the expression, `genCode` adds the expression itself
+ * to `ctx.references` and returns Java source that:
+ *  - calls `eval(i)` on the entry at index `ctx.references.size - 1` of `expressions`, i.e. the
+ *    size measured right after the `+=` (presumably the index of the entry just added, assuming
+ *    `+=` appends -- TODO confirm against CodeGenContext),
+ *  - declares the boolean named by `ev.isNull`, set to whether the evaluated object is null,
+ *  - declares the variable named by `ev.primitive` with the Java type of `dataType`, initialised
+ *    to that type's default value, and
+ *  - when the result is non-null, casts the object to the boxed type of `dataType` and assigns it.
+ *
+ * NOTE(review): `expressions` and `i` are names used inside the generated Java; presumably they
+ * are the references array and the input row of the generated class -- confirm with the generator.
+ *
+ * The self-type (`self: Expression =>`) means the trait can only be mixed into an [[Expression]].
+ * `ctx` supplies fresh names, type names and default values; `ev` supplies the names of the
+ * null-flag and value variables. The returned string is the Java source fragment described above.
+ */
+trait CodegenFallback { self: Expression =>
+
+ protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ ctx.references += this
+ val objectTerm = ctx.freshName("obj")
+ s"""
+ /* expression: ${this} */
+ Object $objectTerm = expressions[${ctx.references.size - 1}].eval(i);
+ boolean ${ev.isNull} = $objectTerm == null;
+ ${ctx.javaType(this.dataType)} ${ev.primitive} = ${ctx.defaultValue(this.dataType)};
+ if (!${ev.isNull}) {
+ ${ev.primitive} = (${ctx.boxedType(this.dataType)}) $objectTerm;
+ }
+ """
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index d1e4c45886..f9fd04c02a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -19,13 +19,14 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
/**
* Returns an Array containing the evaluation of all children expressions.
*/
-case class CreateArray(children: Seq[Expression]) extends Expression {
+case class CreateArray(children: Seq[Expression]) extends Expression with CodegenFallback {
override def foldable: Boolean = children.forall(_.foldable)
@@ -51,7 +52,7 @@ case class CreateArray(children: Seq[Expression]) extends Expression {
* Returns a Row containing the evaluation of all children expressions.
* TODO: [[CreateStruct]] does not support codegen.
*/
-case class CreateStruct(children: Seq[Expression]) extends Expression {
+case class CreateStruct(children: Seq[Expression]) extends Expression with CodegenFallback {
override def foldable: Boolean = children.forall(_.foldable)
@@ -83,7 +84,7 @@ case class CreateStruct(children: Seq[Expression]) extends Expression {
*
* @param children Seq(name1, val1, name2, val2, ...)
*/
-case class CreateNamedStruct(children: Seq[Expression]) extends Expression {
+case class CreateNamedStruct(children: Seq[Expression]) extends Expression with CodegenFallback {
private lazy val (nameExprs, valExprs) =
children.grouped(2).map { case Seq(name, value) => (name, value) }.toList.unzip
@@ -103,11 +104,11 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression {
override def checkInputDataTypes(): TypeCheckResult = {
if (children.size % 2 != 0) {
- TypeCheckResult.TypeCheckFailure("CreateNamedStruct expects an even number of arguments.")
+ TypeCheckResult.TypeCheckFailure(s"$prettyName expects an even number of arguments.")
} else {
val invalidNames =
nameExprs.filterNot(e => e.foldable && e.dataType == StringType && !nullable)
- if (invalidNames.size != 0) {
+ if (invalidNames.nonEmpty) {
TypeCheckResult.TypeCheckFailure(
s"Odd position only allow foldable and not-null StringType expressions, got :" +
s" ${invalidNames.mkString(",")}")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala
index dd5ec330a7..4bed140cff 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
@@ -27,7 +28,7 @@ import org.apache.spark.sql.types._
*
* There is no code generation since this expression should get constant folded by the optimizer.
*/
-case class CurrentDate() extends LeafExpression {
+case class CurrentDate() extends LeafExpression with CodegenFallback {
override def foldable: Boolean = true
override def nullable: Boolean = false
@@ -44,7 +45,7 @@ case class CurrentDate() extends LeafExpression {
*
* There is no code generation since this expression should get constant folded by the optimizer.
*/
-case class CurrentTimestamp() extends LeafExpression {
+case class CurrentTimestamp() extends LeafExpression with CodegenFallback {
override def foldable: Boolean = true
override def nullable: Boolean = false
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
index c58a6d3614..2dbcf2830f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
@@ -20,9 +20,9 @@ package org.apache.spark.sql.catalyst.expressions
import scala.collection.Map
import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
-import org.apache.spark.sql.catalyst.{CatalystTypeConverters, trees}
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.types._
/**
@@ -73,7 +73,7 @@ case class UserDefinedGenerator(
elementTypes: Seq[(DataType, Boolean)],
function: Row => TraversableOnce[InternalRow],
children: Seq[Expression])
- extends Generator {
+ extends Generator with CodegenFallback {
@transient private[this] var inputRow: InterpretedProjection = _
@transient private[this] var convertToScala: (InternalRow) => Row = _
@@ -100,7 +100,7 @@ case class UserDefinedGenerator(
/**
* Given an input array produces a sequence of rows for each value in the array.
*/
-case class Explode(child: Expression) extends UnaryExpression with Generator {
+case class Explode(child: Expression) extends UnaryExpression with Generator with CodegenFallback {
override def children: Seq[Expression] = child :: Nil
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index e1fdb29541..f25ac32679 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -21,7 +21,7 @@ import java.sql.{Date, Timestamp}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.CatalystTypeConverters
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
+import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types._
@@ -75,7 +75,8 @@ object IntegerLiteral {
/**
* In order to do type checking, use Literal.create() instead of constructor
*/
-case class Literal protected (value: Any, dataType: DataType) extends LeafExpression {
+case class Literal protected (value: Any, dataType: DataType)
+ extends LeafExpression with CodegenFallback {
override def foldable: Boolean = true
override def nullable: Boolean = value == null
@@ -142,7 +143,7 @@ case class Literal protected (value: Any, dataType: DataType) extends LeafExpres
// TODO: Specialize
case class MutableLiteral(var value: Any, dataType: DataType, nullable: Boolean = true)
- extends LeafExpression {
+ extends LeafExpression with CodegenFallback {
def update(expression: Expression, input: InternalRow): Unit = {
value = expression.eval(input)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
index eb5c065a34..7ce64d29ba 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
@@ -18,7 +18,6 @@
package org.apache.spark.sql.catalyst.expressions
import java.{lang => jl}
-import java.util.Arrays
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckSuccess, TypeCheckFailure}
@@ -29,11 +28,14 @@ import org.apache.spark.unsafe.types.UTF8String
/**
* A leaf expression specifically for math constants. Math constants expect no input.
+ *
+ * There is no code generation because they should get constant folded by the optimizer.
+ *
* @param c The math constant.
* @param name The short name of the function
*/
abstract class LeafMathExpression(c: Double, name: String)
- extends LeafExpression with Serializable {
+ extends LeafExpression with CodegenFallback {
override def dataType: DataType = DoubleType
override def foldable: Boolean = true
@@ -41,13 +43,6 @@ abstract class LeafMathExpression(c: Double, name: String)
override def toString: String = s"$name()"
override def eval(input: InternalRow): Any = c
-
- override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
- s"""
- boolean ${ev.isNull} = false;
- ${ctx.javaType(dataType)} ${ev.primitive} = java.lang.Math.$name;
- """
- }
}
/**
@@ -130,8 +125,16 @@ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String)
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
+/**
+ * Euler's number. Note that there is no code generation because this is only
+ * evaluated by the optimizer during constant folding.
+ */
case class EulerNumber() extends LeafMathExpression(math.E, "E")
+/**
+ * Pi. Note that there is no code generation because this is only
+ * evaluated by the optimizer during constant folding.
+ */
case class Pi() extends LeafMathExpression(math.Pi, "PI")
////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -161,7 +164,7 @@ case class Cosh(child: Expression) extends UnaryMathExpression(math.cosh, "COSH"
* @param toBaseExpr to which base
*/
case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expression)
- extends Expression with ImplicitCastInputTypes{
+ extends Expression with ImplicitCastInputTypes with CodegenFallback {
override def foldable: Boolean = numExpr.foldable && fromBaseExpr.foldable && toBaseExpr.foldable
@@ -171,6 +174,8 @@ case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expre
override def inputTypes: Seq[AbstractDataType] = Seq(StringType, IntegerType, IntegerType)
+ override def dataType: DataType = StringType
+
/** Returns the result of evaluating this expression on a given input Row */
override def eval(input: InternalRow): Any = {
val num = numExpr.eval(input)
@@ -179,17 +184,13 @@ case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expre
if (num == null || fromBase == null || toBase == null) {
null
} else {
- conv(num.asInstanceOf[UTF8String].getBytes,
- fromBase.asInstanceOf[Int], toBase.asInstanceOf[Int])
+ conv(
+ num.asInstanceOf[UTF8String].getBytes,
+ fromBase.asInstanceOf[Int],
+ toBase.asInstanceOf[Int])
}
}
- /**
- * Returns the [[DataType]] of the result of evaluating this expression. It is
- * invalid to query the dataType of an unresolved expression (i.e., when `resolved` == false).
- */
- override def dataType: DataType = StringType
-
private val value = new Array[Byte](64)
/**
@@ -208,7 +209,7 @@ case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expre
// Two's complement => x = uval - 2*MAX - 2
// => uval = x + 2*MAX + 2
// Now, use the fact: (a+b)/c = a/c + b/c + (a%c+b%c)/c
- (x / m + 2 * (Long.MaxValue / m) + 2 / m + (x % m + 2 * (Long.MaxValue % m) + 2 % m) / m)
+ x / m + 2 * (Long.MaxValue / m) + 2 / m + (x % m + 2 * (Long.MaxValue % m) + 2 % m) / m
}
}
@@ -220,7 +221,7 @@ case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expre
*/
private def decode(v: Long, radix: Int): Unit = {
var tmpV = v
- Arrays.fill(value, 0.asInstanceOf[Byte])
+ java.util.Arrays.fill(value, 0.asInstanceOf[Byte])
var i = value.length - 1
while (tmpV != 0) {
val q = unsignedLongDiv(tmpV, radix)
@@ -254,7 +255,7 @@ case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expre
v = v * radix + value(i)
i += 1
}
- return v
+ v
}
/**
@@ -292,16 +293,16 @@ case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expre
* NB: This logic is borrowed from org.apache.hadoop.hive.ql.udf.UDFConv
*/
private def conv(n: Array[Byte] , fromBase: Int, toBase: Int ): UTF8String = {
- if (n == null || fromBase == null || toBase == null || n.isEmpty) {
- return null
- }
-
if (fromBase < Character.MIN_RADIX || fromBase > Character.MAX_RADIX
|| Math.abs(toBase) < Character.MIN_RADIX
|| Math.abs(toBase) > Character.MAX_RADIX) {
return null
}
+ if (n.length == 0) {
+ return null
+ }
+
var (negative, first) = if (n(0) == '-') (true, 1) else (false, 0)
// Copy the digits in the right side of the array
@@ -340,7 +341,7 @@ case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expre
resultStartPos = firstNonZeroPos - 1
value(resultStartPos) = '-'
}
- UTF8String.fromBytes( Arrays.copyOfRange(value, resultStartPos, value.length))
+ UTF8String.fromBytes(java.util.Arrays.copyOfRange(value, resultStartPos, value.length))
}
}
@@ -495,8 +496,8 @@ object Hex {
* Otherwise if the number is a STRING, it converts each character into its hex representation
* and returns the resulting STRING. Negative numbers would be treated as two's complement.
*/
-case class Hex(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
- // TODO: Create code-gen version.
+case class Hex(child: Expression)
+ extends UnaryExpression with ImplicitCastInputTypes with CodegenFallback {
override def inputTypes: Seq[AbstractDataType] =
Seq(TypeCollection(LongType, BinaryType, StringType))
@@ -539,8 +540,8 @@ case class Hex(child: Expression) extends UnaryExpression with ImplicitCastInput
* Performs the inverse operation of HEX.
* Resulting characters are returned as a byte array.
*/
-case class Unhex(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
- // TODO: Create code-gen version.
+case class Unhex(child: Expression)
+ extends UnaryExpression with ImplicitCastInputTypes with CodegenFallback {
override def inputTypes: Seq[AbstractDataType] = Seq(StringType)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index c083ac08de..6f173b52ad 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -19,9 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
-import org.apache.spark.sql.catalyst.errors.TreeNodeException
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
-import org.apache.spark.sql.catalyst.trees
+import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.types._
object NamedExpression {
@@ -122,7 +120,9 @@ case class Alias(child: Expression, name: String)(
override def eval(input: InternalRow): Any = child.eval(input)
+ /** Just a simple passthrough for code generation. */
override def gen(ctx: CodeGenContext): GeneratedExpressionCode = child.gen(ctx)
+ override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = ""
override def dataType: DataType = child.dataType
override def nullable: Boolean = child.nullable
@@ -177,7 +177,7 @@ case class AttributeReference(
override val metadata: Metadata = Metadata.empty)(
val exprId: ExprId = NamedExpression.newExprId,
val qualifiers: Seq[String] = Nil)
- extends Attribute {
+ extends Attribute with Unevaluable {
/**
* Returns true iff the expression id is the same for both attributes.
@@ -236,10 +236,6 @@ case class AttributeReference(
}
}
- // Unresolved attributes are transient at compile time and don't get evaluated during execution.
- override def eval(input: InternalRow = null): Any =
- throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
-
override def toString: String = s"$name#${exprId.id}$typeSuffix"
}
@@ -247,7 +243,7 @@ case class AttributeReference(
* A place holder used when printing expressions without debugging information such as the
* expression id or the unresolved indicator.
*/
-case class PrettyAttribute(name: String) extends Attribute {
+case class PrettyAttribute(name: String) extends Attribute with Unevaluable {
override def toString: String = name
@@ -259,7 +255,6 @@ case class PrettyAttribute(name: String) extends Attribute {
override def withName(newName: String): Attribute = throw new UnsupportedOperationException
override def qualifiers: Seq[String] = throw new UnsupportedOperationException
override def exprId: ExprId = throw new UnsupportedOperationException
- override def eval(input: InternalRow): Any = throw new UnsupportedOperationException
override def nullable: Boolean = throw new UnsupportedOperationException
override def dataType: DataType = NullType
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index bddd2a9ecc..40ec3df224 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -19,10 +19,11 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.TypeUtils
-import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
+import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.types._
+
object InterpretedPredicate {
def create(expression: Expression, inputSchema: Seq[Attribute]): (InternalRow => Boolean) =
create(BindReferences.bindReference(expression, inputSchema))
@@ -91,7 +92,7 @@ case class Not(child: Expression)
/**
* Evaluates to `true` if `list` contains `value`.
*/
-case class In(value: Expression, list: Seq[Expression]) extends Predicate {
+case class In(value: Expression, list: Seq[Expression]) extends Predicate with CodegenFallback {
override def children: Seq[Expression] = value +: list
override def nullable: Boolean = true // TODO: Figure out correct nullability semantics of IN.
@@ -109,7 +110,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
* static.
*/
case class InSet(child: Expression, hset: Set[Any])
- extends UnaryExpression with Predicate {
+ extends UnaryExpression with Predicate with CodegenFallback {
override def nullable: Boolean = true // TODO: Figure out correct nullability semantics of IN.
override def toString: String = s"$child INSET ${hset.mkString("(", ",", ")")}"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala
index 49b2026364..5b0fe8dfe2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
+import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.types._
import org.apache.spark.util.collection.OpenHashSet
@@ -52,7 +52,7 @@ private[sql] class OpenHashSetUDT(
/**
* Creates a new set of the specified type
*/
-case class NewSet(elementType: DataType) extends LeafExpression {
+case class NewSet(elementType: DataType) extends LeafExpression with CodegenFallback {
override def nullable: Boolean = false
@@ -82,7 +82,8 @@ case class NewSet(elementType: DataType) extends LeafExpression {
* Note: this expression is internal and created only by the GeneratedAggregate,
* we don't need to do type check for it.
*/
-case class AddItemToSet(item: Expression, set: Expression) extends Expression {
+case class AddItemToSet(item: Expression, set: Expression)
+ extends Expression with CodegenFallback {
override def children: Seq[Expression] = item :: set :: Nil
@@ -134,7 +135,8 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression {
* Note: this expression is internal and created only by the GeneratedAggregate,
* we don't need to do type check for it.
*/
-case class CombineSets(left: Expression, right: Expression) extends BinaryExpression {
+case class CombineSets(left: Expression, right: Expression)
+ extends BinaryExpression with CodegenFallback {
override def nullable: Boolean = left.nullable
override def dataType: DataType = left.dataType
@@ -181,7 +183,7 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres
* Note: this expression is internal and created only by the GeneratedAggregate,
* we don't need to do type check for it.
*/
-case class CountSet(child: Expression) extends UnaryExpression {
+case class CountSet(child: Expression) extends UnaryExpression with CodegenFallback {
override def dataType: DataType = LongType
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index b36354eff0..560b1bc2d8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -103,7 +103,7 @@ trait StringRegexExpression extends ImplicitCastInputTypes {
* Simple RegEx pattern matching function
*/
case class Like(left: Expression, right: Expression)
- extends BinaryExpression with StringRegexExpression {
+ extends BinaryExpression with StringRegexExpression with CodegenFallback {
// replace the _ with .{1} exactly match 1 time of any character
// replace the % with .*, match 0 or more times with any character
@@ -133,14 +133,16 @@ case class Like(left: Expression, right: Expression)
override def toString: String = s"$left LIKE $right"
}
+
case class RLike(left: Expression, right: Expression)
- extends BinaryExpression with StringRegexExpression {
+ extends BinaryExpression with StringRegexExpression with CodegenFallback {
override def escape(v: String): String = v
override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0)
override def toString: String = s"$left RLIKE $right"
}
+
trait String2StringExpression extends ImplicitCastInputTypes {
self: UnaryExpression =>
@@ -156,7 +158,8 @@ trait String2StringExpression extends ImplicitCastInputTypes {
/**
* A function that converts the characters of a string to uppercase.
*/
-case class Upper(child: Expression) extends UnaryExpression with String2StringExpression {
+case class Upper(child: Expression)
+ extends UnaryExpression with String2StringExpression {
override def convert(v: UTF8String): UTF8String = v.toUpperCase
@@ -301,7 +304,7 @@ case class StringInstr(str: Expression, substr: Expression)
* in given string after position pos.
*/
case class StringLocate(substr: Expression, str: Expression, start: Expression)
- extends Expression with ImplicitCastInputTypes {
+ extends Expression with ImplicitCastInputTypes with CodegenFallback {
def this(substr: Expression, str: Expression) = {
this(substr, str, Literal(0))
@@ -342,7 +345,7 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression)
* Returns str, left-padded with pad to a length of len.
*/
case class StringLPad(str: Expression, len: Expression, pad: Expression)
- extends Expression with ImplicitCastInputTypes {
+ extends Expression with ImplicitCastInputTypes with CodegenFallback {
override def children: Seq[Expression] = str :: len :: pad :: Nil
override def foldable: Boolean = children.forall(_.foldable)
@@ -380,7 +383,7 @@ case class StringLPad(str: Expression, len: Expression, pad: Expression)
* Returns str, right-padded with pad to a length of len.
*/
case class StringRPad(str: Expression, len: Expression, pad: Expression)
- extends Expression with ImplicitCastInputTypes {
+ extends Expression with ImplicitCastInputTypes with CodegenFallback {
override def children: Seq[Expression] = str :: len :: pad :: Nil
override def foldable: Boolean = children.forall(_.foldable)
@@ -417,9 +420,9 @@ case class StringRPad(str: Expression, len: Expression, pad: Expression)
/**
* Returns the input formatted according to printf-style format strings
*/
-case class StringFormat(children: Expression*) extends Expression {
+case class StringFormat(children: Expression*) extends Expression with CodegenFallback {
- require(children.length >=1, "printf() should take at least 1 argument")
+ require(children.nonEmpty, "printf() should take at least 1 argument")
override def foldable: Boolean = children.forall(_.foldable)
override def nullable: Boolean = children(0).nullable
@@ -436,7 +439,7 @@ case class StringFormat(children: Expression*) extends Expression {
val formatter = new java.util.Formatter(sb, Locale.US)
val arglist = args.map(_.eval(input).asInstanceOf[AnyRef])
- formatter.format(pattern.asInstanceOf[UTF8String].toString(), arglist: _*)
+ formatter.format(pattern.asInstanceOf[UTF8String].toString, arglist: _*)
UTF8String.fromString(sb.toString)
}
@@ -483,7 +486,8 @@ case class StringReverse(child: Expression) extends UnaryExpression with String2
/**
* Returns a string consisting of n spaces.
*/
-case class StringSpace(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class StringSpace(child: Expression)
+ extends UnaryExpression with ImplicitCastInputTypes with CodegenFallback {
override def dataType: DataType = StringType
override def inputTypes: Seq[DataType] = Seq(IntegerType)
@@ -503,7 +507,7 @@ case class StringSpace(child: Expression) extends UnaryExpression with ImplicitC
* Splits str around pat (pattern is a regular expression).
*/
case class StringSplit(str: Expression, pattern: Expression)
- extends BinaryExpression with ImplicitCastInputTypes {
+ extends BinaryExpression with ImplicitCastInputTypes with CodegenFallback {
override def left: Expression = str
override def right: Expression = pattern
@@ -524,7 +528,7 @@ case class StringSplit(str: Expression, pattern: Expression)
* Defined for String and Binary types.
*/
case class Substring(str: Expression, pos: Expression, len: Expression)
- extends Expression with ImplicitCastInputTypes {
+ extends Expression with ImplicitCastInputTypes with CodegenFallback {
def this(str: Expression, pos: Expression) = {
this(str, pos, Literal(Integer.MAX_VALUE))
@@ -606,8 +610,6 @@ case class Length(child: Expression) extends UnaryExpression with ExpectsInputTy
case BinaryType => defineCodeGen(ctx, ev, c => s"($c).length")
}
}
-
- override def prettyName: String = "length"
}
/**
@@ -632,7 +634,9 @@ case class Levenshtein(left: Expression, right: Expression) extends BinaryExpres
/**
* Returns the numeric value of the first character of str.
*/
-case class Ascii(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class Ascii(child: Expression)
+ extends UnaryExpression with ImplicitCastInputTypes with CodegenFallback {
+
override def dataType: DataType = IntegerType
override def inputTypes: Seq[DataType] = Seq(StringType)
@@ -649,7 +653,9 @@ case class Ascii(child: Expression) extends UnaryExpression with ImplicitCastInp
/**
* Converts the argument from binary to a base 64 string.
*/
-case class Base64(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class Base64(child: Expression)
+ extends UnaryExpression with ImplicitCastInputTypes with CodegenFallback {
+
override def dataType: DataType = StringType
override def inputTypes: Seq[DataType] = Seq(BinaryType)
@@ -663,7 +669,9 @@ case class Base64(child: Expression) extends UnaryExpression with ImplicitCastIn
/**
* Converts the argument from a base 64 string to BINARY.
*/
-case class UnBase64(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class UnBase64(child: Expression)
+ extends UnaryExpression with ImplicitCastInputTypes with CodegenFallback {
+
override def dataType: DataType = BinaryType
override def inputTypes: Seq[DataType] = Seq(StringType)
@@ -677,7 +685,7 @@ case class UnBase64(child: Expression) extends UnaryExpression with ImplicitCast
* If either argument is null, the result will also be null.
*/
case class Decode(bin: Expression, charset: Expression)
- extends BinaryExpression with ImplicitCastInputTypes {
+ extends BinaryExpression with ImplicitCastInputTypes with CodegenFallback {
override def left: Expression = bin
override def right: Expression = charset
@@ -696,7 +704,7 @@ case class Decode(bin: Expression, charset: Expression)
* If either argument is null, the result will also be null.
*/
case class Encode(value: Expression, charset: Expression)
- extends BinaryExpression with ImplicitCastInputTypes {
+ extends BinaryExpression with ImplicitCastInputTypes with CodegenFallback {
override def left: Expression = value
override def right: Expression = charset
@@ -715,7 +723,7 @@ case class Encode(value: Expression, charset: Expression)
* fractional part.
*/
case class FormatNumber(x: Expression, d: Expression)
- extends BinaryExpression with ExpectsInputTypes {
+ extends BinaryExpression with ExpectsInputTypes with CodegenFallback {
override def left: Expression = x
override def right: Expression = d
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
index c8aa571df6..50bbfd644d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.UnresolvedException
-import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.types.{DataType, NumericType}
/**
@@ -37,7 +36,7 @@ sealed trait WindowSpec
case class WindowSpecDefinition(
partitionSpec: Seq[Expression],
orderSpec: Seq[SortOrder],
- frameSpecification: WindowFrame) extends Expression with WindowSpec {
+ frameSpecification: WindowFrame) extends Expression with WindowSpec with Unevaluable {
def validate: Option[String] = frameSpecification match {
case UnspecifiedFrame =>
@@ -75,7 +74,6 @@ case class WindowSpecDefinition(
override def toString: String = simpleString
- override def eval(input: InternalRow): Any = throw new UnsupportedOperationException
override def nullable: Boolean = true
override def foldable: Boolean = false
override def dataType: DataType = throw new UnsupportedOperationException
@@ -274,60 +272,43 @@ trait WindowFunction extends Expression {
case class UnresolvedWindowFunction(
name: String,
children: Seq[Expression])
- extends Expression with WindowFunction {
+ extends Expression with WindowFunction with Unevaluable {
override def dataType: DataType = throw new UnresolvedException(this, "dataType")
override def foldable: Boolean = throw new UnresolvedException(this, "foldable")
override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
override lazy val resolved = false
- override def init(): Unit =
- throw new UnresolvedException(this, "init")
- override def reset(): Unit =
- throw new UnresolvedException(this, "reset")
+ override def init(): Unit = throw new UnresolvedException(this, "init")
+ override def reset(): Unit = throw new UnresolvedException(this, "reset")
override def prepareInputParameters(input: InternalRow): AnyRef =
throw new UnresolvedException(this, "prepareInputParameters")
- override def update(input: AnyRef): Unit =
- throw new UnresolvedException(this, "update")
+ override def update(input: AnyRef): Unit = throw new UnresolvedException(this, "update")
override def batchUpdate(inputs: Array[AnyRef]): Unit =
throw new UnresolvedException(this, "batchUpdate")
- override def evaluate(): Unit =
- throw new UnresolvedException(this, "evaluate")
- override def get(index: Int): Any =
- throw new UnresolvedException(this, "get")
- // Unresolved functions are transient at compile time and don't get evaluated during execution.
- override def eval(input: InternalRow = null): Any =
- throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
+ override def evaluate(): Unit = throw new UnresolvedException(this, "evaluate")
+ override def get(index: Int): Any = throw new UnresolvedException(this, "get")
override def toString: String = s"'$name(${children.mkString(",")})"
- override def newInstance(): WindowFunction =
- throw new UnresolvedException(this, "newInstance")
+ override def newInstance(): WindowFunction = throw new UnresolvedException(this, "newInstance")
}
case class UnresolvedWindowExpression(
child: UnresolvedWindowFunction,
- windowSpec: WindowSpecReference) extends UnaryExpression {
+ windowSpec: WindowSpecReference) extends UnaryExpression with Unevaluable {
override def dataType: DataType = throw new UnresolvedException(this, "dataType")
override def foldable: Boolean = throw new UnresolvedException(this, "foldable")
override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
override lazy val resolved = false
-
- // Unresolved functions are transient at compile time and don't get evaluated during execution.
- override def eval(input: InternalRow = null): Any =
- throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
}
case class WindowExpression(
windowFunction: WindowFunction,
- windowSpec: WindowSpecDefinition) extends Expression {
-
- override def children: Seq[Expression] =
- windowFunction :: windowSpec :: Nil
+ windowSpec: WindowSpecDefinition) extends Expression with Unevaluable {
- override def eval(input: InternalRow): Any =
- throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
+ override def children: Seq[Expression] = windowFunction :: windowSpec :: Nil
override def dataType: DataType = windowFunction.dataType
override def foldable: Boolean = windowFunction.foldable
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index 42dead7c28..2dcfa19fec 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -17,9 +17,7 @@
package org.apache.spark.sql.catalyst.plans.physical
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.errors.TreeNodeException
-import org.apache.spark.sql.catalyst.expressions.{Expression, SortOrder}
+import org.apache.spark.sql.catalyst.expressions.{Unevaluable, Expression, SortOrder}
import org.apache.spark.sql.types.{DataType, IntegerType}
/**
@@ -146,8 +144,7 @@ case object BroadcastPartitioning extends Partitioning {
* in the same partition.
*/
case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
- extends Expression
- with Partitioning {
+ extends Expression with Partitioning with Unevaluable {
override def children: Seq[Expression] = expressions
override def nullable: Boolean = false
@@ -169,9 +166,6 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
}
override def keyExpressions: Seq[Expression] = expressions
-
- override def eval(input: InternalRow = null): Any =
- throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
}
/**
@@ -187,8 +181,7 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
* into its child.
*/
case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int)
- extends Expression
- with Partitioning {
+ extends Expression with Partitioning with Unevaluable {
override def children: Seq[SortOrder] = ordering
override def nullable: Boolean = false
@@ -213,7 +206,4 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int)
}
override def keyExpressions: Seq[Expression] = ordering.map(_.child)
-
- override def eval(input: InternalRow): Any =
- throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
index 2147d07e09..dca8c881f2 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -31,9 +31,9 @@ import org.apache.spark.sql.types._
case class TestFunction(
children: Seq[Expression],
- inputTypes: Seq[AbstractDataType]) extends Expression with ImplicitCastInputTypes {
+ inputTypes: Seq[AbstractDataType])
+ extends Expression with ImplicitCastInputTypes with Unevaluable {
override def nullable: Boolean = true
- override def eval(input: InternalRow): Any = throw new UnsupportedOperationException
override def dataType: DataType = StringType
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
index c9b3c69c6d..f9442bccc4 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
@@ -363,26 +363,26 @@ class HiveTypeCoercionSuite extends PlanTest {
object HiveTypeCoercionSuite {
case class AnyTypeUnaryExpression(child: Expression)
- extends UnaryExpression with ExpectsInputTypes {
+ extends UnaryExpression with ExpectsInputTypes with Unevaluable {
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
override def dataType: DataType = NullType
}
case class NumericTypeUnaryExpression(child: Expression)
- extends UnaryExpression with ExpectsInputTypes {
+ extends UnaryExpression with ExpectsInputTypes with Unevaluable {
override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
override def dataType: DataType = NullType
}
case class AnyTypeBinaryOperator(left: Expression, right: Expression)
- extends BinaryOperator {
+ extends BinaryOperator with Unevaluable {
override def dataType: DataType = NullType
override def inputType: AbstractDataType = AnyDataType
override def symbol: String = "anytype"
}
case class NumericTypeBinaryOperator(left: Expression, right: Expression)
- extends BinaryOperator {
+ extends BinaryOperator with Unevaluable {
override def dataType: DataType = NullType
override def inputType: AbstractDataType = NumericType
override def symbol: String = "numerictype"
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
index 1bd7d4e5cd..8fff39906b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
@@ -22,9 +22,10 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.types.{IntegerType, StringType, NullType}
-case class Dummy(optKey: Option[Expression]) extends Expression {
+case class Dummy(optKey: Option[Expression]) extends Expression with CodegenFallback {
override def children: Seq[Expression] = optKey.toSeq
override def nullable: Boolean = true
override def dataType: NullType = NullType
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
index 6d6e67dace..e6e27a87c7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
@@ -51,15 +51,11 @@ private[spark] case class PythonUDF(
broadcastVars: JList[Broadcast[PythonBroadcast]],
accumulator: Accumulator[JList[Array[Byte]]],
dataType: DataType,
- children: Seq[Expression]) extends Expression with SparkLogging {
+ children: Seq[Expression]) extends Expression with Unevaluable with SparkLogging {
override def toString: String = s"PythonUDF#$name(${children.mkString(",")})"
override def nullable: Boolean = true
-
- override def eval(input: InternalRow): Any = {
- throw new UnsupportedOperationException("PythonUDFs can not be directly evaluated.")
- }
}
/**
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
index 0bc8adb16a..4d23c7035c 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
@@ -36,8 +36,8 @@ import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
-import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.hive.HiveShim._
@@ -81,7 +81,7 @@ private[hive] class HiveFunctionRegistry(underlying: analysis.FunctionRegistry)
}
private[hive] case class HiveSimpleUDF(funcWrapper: HiveFunctionWrapper, children: Seq[Expression])
- extends Expression with HiveInspectors with Logging {
+ extends Expression with HiveInspectors with CodegenFallback with Logging {
type UDFType = UDF
@@ -146,7 +146,7 @@ private[hive] class DeferredObjectAdapter(oi: ObjectInspector)
}
private[hive] case class HiveGenericUDF(funcWrapper: HiveFunctionWrapper, children: Seq[Expression])
- extends Expression with HiveInspectors with Logging {
+ extends Expression with HiveInspectors with CodegenFallback with Logging {
type UDFType = GenericUDF
override def deterministic: Boolean = isUDFDeterministic
@@ -166,8 +166,8 @@ private[hive] case class HiveGenericUDF(funcWrapper: HiveFunctionWrapper, childr
@transient
protected lazy val isUDFDeterministic = {
- val udfType = function.getClass().getAnnotation(classOf[HiveUDFType])
- (udfType != null && udfType.deterministic())
+ val udfType = function.getClass.getAnnotation(classOf[HiveUDFType])
+ udfType != null && udfType.deterministic()
}
override def foldable: Boolean =
@@ -301,7 +301,7 @@ private[hive] case class HiveWindowFunction(
pivotResult: Boolean,
isUDAFBridgeRequired: Boolean,
children: Seq[Expression]) extends WindowFunction
- with HiveInspectors {
+ with HiveInspectors with Unevaluable {
// Hive window functions are based on GenericUDAFResolver2.
type UDFType = GenericUDAFResolver2
@@ -330,7 +330,7 @@ private[hive] case class HiveWindowFunction(
evaluator.init(GenericUDAFEvaluator.Mode.COMPLETE, inputInspectors)
}
- def dataType: DataType =
+ override def dataType: DataType =
if (!pivotResult) {
inspectorToDataType(returnInspector)
} else {
@@ -344,10 +344,7 @@ private[hive] case class HiveWindowFunction(
}
}
- def nullable: Boolean = true
-
- override def eval(input: InternalRow): Any =
- throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
+ override def nullable: Boolean = true
@transient
lazy val inputProjection = new InterpretedProjection(children)
@@ -406,7 +403,7 @@ private[hive] case class HiveWindowFunction(
s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
}
- override def newInstance: WindowFunction =
+ override def newInstance(): WindowFunction =
new HiveWindowFunction(funcWrapper, pivotResult, isUDAFBridgeRequired, children)
}
@@ -476,7 +473,7 @@ private[hive] case class HiveUDAF(
/**
* Converts a Hive Generic User Defined Table Generating Function (UDTF) to a
- * [[catalyst.expressions.Generator Generator]]. Note that the semantics of Generators do not allow
+ * [[Generator]]. Note that the semantics of Generators do not allow
* Generators to maintain state in between input rows. Thus UDTFs that rely on partitioning
* dependent operations like calls to `close()` before producing output will not operate the same as
* in Hive. However, in practice this should not affect compatibility for most sane UDTFs
@@ -488,7 +485,7 @@ private[hive] case class HiveUDAF(
private[hive] case class HiveGenericUDTF(
funcWrapper: HiveFunctionWrapper,
children: Seq[Expression])
- extends Generator with HiveInspectors {
+ extends Generator with HiveInspectors with CodegenFallback {
@transient
protected lazy val function: GenericUDTF = {