aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst/src/main
diff options
context:
space:
mode:
authorpetermaxlee <petermaxlee@gmail.com>2016-08-19 09:19:47 +0800
committerWenchen Fan <wenchen@databricks.com>2016-08-19 09:19:47 +0800
commitf5472dda51b980a726346587257c22873ff708e3 (patch)
tree39511907b8f69c02626af8603013a94461388281 /sql/catalyst/src/main
parentb72bb62d421840f82d663c6b8e3922bd14383fbb (diff)
downloadspark-f5472dda51b980a726346587257c22873ff708e3.tar.gz
spark-f5472dda51b980a726346587257c22873ff708e3.tar.bz2
spark-f5472dda51b980a726346587257c22873ff708e3.zip
[SPARK-16947][SQL] Support type coercion and foldable expression for inline tables
## What changes were proposed in this pull request? This patch improves inline table support with the following: 1. Support type coercion. 2. Support using foldable expressions. Previously only literals were supported. 3. Improve error message handling. 4. Improve test coverage. ## How was this patch tested? Added a new unit test suite ResolveInlineTablesSuite and a new file-based end-to-end test inline-table.sql. Author: petermaxlee <petermaxlee@gmail.com> Closes #14676 from petermaxlee/SPARK-16947.
Diffstat (limited to 'sql/catalyst/src/main')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala1
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala112
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala2
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala26
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala41
5 files changed, 150 insertions, 32 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 333dd4d9a4..41e0e6d65e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -108,6 +108,7 @@ class Analyzer(
GlobalAggregates ::
ResolveAggregateFunctions ::
TimeWindowing ::
+ ResolveInlineTables ::
TypeCoercion.typeCoercionRules ++
extendedResolutionRules : _*),
Batch("Nondeterministic", Once,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala
new file mode 100644
index 0000000000..7323197b10
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import scala.util.control.NonFatal
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Cast
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.types.{StructField, StructType}
+
+/**
+ * An analyzer rule that replaces [[UnresolvedInlineTable]] with [[LocalRelation]].
+ */
+object ResolveInlineTables extends Rule[LogicalPlan] {
+ override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+ case table: UnresolvedInlineTable if table.expressionsResolved =>
+ validateInputDimension(table)
+ validateInputEvaluable(table)
+ convert(table)
+ }
+
+ /**
+ * Validates the input data dimension:
+ * 1. All rows have the same cardinality.
+ * 2. The number of column aliases defined is consistent with the number of columns in data.
+ *
+ * This is package visible for unit testing.
+ */
+ private[analysis] def validateInputDimension(table: UnresolvedInlineTable): Unit = {
+ if (table.rows.nonEmpty) {
+ val numCols = table.names.size
+ table.rows.zipWithIndex.foreach { case (row, ri) =>
+ if (row.size != numCols) {
+ table.failAnalysis(s"expected $numCols columns but found ${row.size} columns in row $ri")
+ }
+ }
+ }
+ }
+
+ /**
+ * Validates that all inline table data are valid expressions that can be evaluated
+ * (in this they must be foldable).
+ *
+ * This is package visible for unit testing.
+ */
+ private[analysis] def validateInputEvaluable(table: UnresolvedInlineTable): Unit = {
+ table.rows.foreach { row =>
+ row.foreach { e =>
+ // Note that nondeterministic expressions are not supported since they are not foldable.
+ if (!e.resolved || !e.foldable) {
+ e.failAnalysis(s"cannot evaluate expression ${e.sql} in inline table definition")
+ }
+ }
+ }
+ }
+
+ /**
+ * Convert a valid (with right shape and foldable inputs) [[UnresolvedInlineTable]]
+ * into a [[LocalRelation]].
+ *
+ * This function attempts to coerce inputs into consistent types.
+ *
+ * This is package visible for unit testing.
+ */
+ private[analysis] def convert(table: UnresolvedInlineTable): LocalRelation = {
+ // For each column, traverse all the values and find a common data type and nullability.
+ val fields = table.rows.transpose.zip(table.names).map { case (column, name) =>
+ val inputTypes = column.map(_.dataType)
+ val tpe = TypeCoercion.findWiderTypeWithoutStringPromotion(inputTypes).getOrElse {
+ table.failAnalysis(s"incompatible types found in column $name for inline table")
+ }
+ StructField(name, tpe, nullable = column.exists(_.nullable))
+ }
+ val attributes = StructType(fields).toAttributes
+ assert(fields.size == table.names.size)
+
+ val newRows: Seq[InternalRow] = table.rows.map { row =>
+ InternalRow.fromSeq(row.zipWithIndex.map { case (e, ci) =>
+ val targetType = fields(ci).dataType
+ try {
+ if (e.dataType.sameType(targetType)) {
+ e.eval()
+ } else {
+ Cast(e, targetType).eval()
+ }
+ } catch {
+ case NonFatal(ex) =>
+ table.failAnalysis(s"failed to evaluate expression ${e.sql}: ${ex.getMessage}")
+ }
+ })
+ }
+
+ LocalRelation(attributes, newRows)
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index 21e96aaf53..193c3ec4e5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -150,7 +150,7 @@ object TypeCoercion {
* [[findTightestCommonType]], but can handle decimal types. If the wider decimal type exceeds
* system limitation, this rule will truncate the decimal type before return it.
*/
- private def findWiderTypeWithoutStringPromotion(types: Seq[DataType]): Option[DataType] = {
+ def findWiderTypeWithoutStringPromotion(types: Seq[DataType]): Option[DataType] = {
types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match {
case Some(d) => findTightestCommonTypeOfTwo(d, c).orElse((d, c) match {
case (t1: DecimalType, t2: DecimalType) =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index 3735a1501c..235ae04782 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -50,10 +50,30 @@ case class UnresolvedRelation(
}
/**
- * Holds a table-valued function call that has yet to be resolved.
+ * An inline table that has not been resolved yet. Once resolved, it is turned by the analyzer into
+ * a [[org.apache.spark.sql.catalyst.plans.logical.LocalRelation]].
+ *
+ * @param names list of column names
+ * @param rows expressions for the data
+ */
+case class UnresolvedInlineTable(
+ names: Seq[String],
+ rows: Seq[Seq[Expression]])
+ extends LeafNode {
+
+ lazy val expressionsResolved: Boolean = rows.forall(_.forall(_.resolved))
+ override lazy val resolved = false
+ override def output: Seq[Attribute] = Nil
+}
+
+/**
+ * A table-valued function, e.g.
+ * {{{
+ * select * from range(10);
+ * }}}
*/
-case class UnresolvedTableValuedFunction(
- functionName: String, functionArgs: Seq[Expression]) extends LeafNode {
+case class UnresolvedTableValuedFunction(functionName: String, functionArgs: Seq[Expression])
+ extends LeafNode {
override def output: Seq[Attribute] = Nil
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 01322ae327..283e4d43ba 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -670,39 +670,24 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
*/
override def visitInlineTable(ctx: InlineTableContext): LogicalPlan = withOrigin(ctx) {
// Get the backing expressions.
- val expressions = ctx.expression.asScala.map { eCtx =>
- val e = expression(eCtx)
- validate(e.foldable, "All expressions in an inline table must be constants.", eCtx)
- e
- }
-
- // Validate and evaluate the rows.
- val (structType, structConstructor) = expressions.head.dataType match {
- case st: StructType =>
- (st, (e: Expression) => e)
- case dt =>
- val st = CreateStruct(Seq(expressions.head)).dataType
- (st, (e: Expression) => CreateStruct(Seq(e)))
- }
- val rows = expressions.map {
- case expression =>
- val safe = Cast(structConstructor(expression), structType)
- safe.eval().asInstanceOf[InternalRow]
+ val rows = ctx.expression.asScala.map { e =>
+ expression(e) match {
+ // inline table comes in two styles:
+ // style 1: values (1), (2), (3) -- multiple columns are supported
+ // style 2: values 1, 2, 3 -- only a single column is supported here
+ case CreateStruct(children) => children // style 1
+ case child => Seq(child) // style 2
+ }
}
- // Construct attributes.
- val baseAttributes = structType.toAttributes.map(_.withNullability(true))
- val attributes = if (ctx.identifierList != null) {
- val aliases = visitIdentifierList(ctx.identifierList)
- validate(aliases.size == baseAttributes.size,
- "Number of aliases must match the number of fields in an inline table.", ctx)
- baseAttributes.zip(aliases).map(p => p._1.withName(p._2))
+ val aliases = if (ctx.identifierList != null) {
+ visitIdentifierList(ctx.identifierList)
} else {
- baseAttributes
+ Seq.tabulate(rows.head.size)(i => s"col${i + 1}")
}
- // Create plan and add an alias if a name has been defined.
- LocalRelation(attributes, rows).optionalMap(ctx.identifier)(aliasPlan)
+ val table = UnresolvedInlineTable(aliases, rows)
+ table.optionalMap(ctx.identifier)(aliasPlan)
}
/**