From f5472dda51b980a726346587257c22873ff708e3 Mon Sep 17 00:00:00 2001 From: petermaxlee Date: Fri, 19 Aug 2016 09:19:47 +0800 Subject: [SPARK-16947][SQL] Support type coercion and foldable expression for inline tables ## What changes were proposed in this pull request? This patch improves inline table support with the following: 1. Support type coercion. 2. Support using foldable expressions. Previously only literals were supported. 3. Improve error message handling. 4. Improve test coverage. ## How was this patch tested? Added a new unit test suite ResolveInlineTablesSuite and a new file-based end-to-end test inline-table.sql. Author: petermaxlee Closes #14676 from petermaxlee/SPARK-16947. --- .../spark/sql/catalyst/analysis/Analyzer.scala | 1 + .../catalyst/analysis/ResolveInlineTables.scala | 112 +++++++++++++++++++++ .../spark/sql/catalyst/analysis/TypeCoercion.scala | 2 +- .../spark/sql/catalyst/analysis/unresolved.scala | 26 ++++- .../spark/sql/catalyst/parser/AstBuilder.scala | 41 +++----- .../analysis/ResolveInlineTablesSuite.scala | 101 +++++++++++++++++++ .../sql/catalyst/parser/PlanParserSuite.scala | 22 ++-- 7 files changed, 259 insertions(+), 46 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala (limited to 'sql/catalyst') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 333dd4d9a4..41e0e6d65e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -108,6 +108,7 @@ class Analyzer( GlobalAggregates :: ResolveAggregateFunctions :: TimeWindowing :: + ResolveInlineTables :: TypeCoercion.typeCoercionRules ++ extendedResolutionRules : _*), Batch("Nondeterministic", Once, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala new file mode 100644 index 0000000000..7323197b10 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import scala.util.control.NonFatal + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Cast +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.types.{StructField, StructType} + +/** + * An analyzer rule that replaces [[UnresolvedInlineTable]] with [[LocalRelation]]. + */ +object ResolveInlineTables extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + case table: UnresolvedInlineTable if table.expressionsResolved => + validateInputDimension(table) + validateInputEvaluable(table) + convert(table) + } + + /** + * Validates the input data dimension: + * 1. All rows have the same cardinality. + * 2. The number of column aliases defined is consistent with the number of columns in data. + * + * This is package visible for unit testing. + */ + private[analysis] def validateInputDimension(table: UnresolvedInlineTable): Unit = { + if (table.rows.nonEmpty) { + val numCols = table.names.size + table.rows.zipWithIndex.foreach { case (row, ri) => + if (row.size != numCols) { + table.failAnalysis(s"expected $numCols columns but found ${row.size} columns in row $ri") + } + } + } + } + + /** + * Validates that all inline table data are valid expressions that can be evaluated + * (in this they must be foldable). + * + * This is package visible for unit testing. + */ + private[analysis] def validateInputEvaluable(table: UnresolvedInlineTable): Unit = { + table.rows.foreach { row => + row.foreach { e => + // Note that nondeterministic expressions are not supported since they are not foldable. + if (!e.resolved || !e.foldable) { + e.failAnalysis(s"cannot evaluate expression ${e.sql} in inline table definition") + } + } + } + } + + /** + * Convert a valid (with right shape and foldable inputs) [[UnresolvedInlineTable]] + * into a [[LocalRelation]]. + * + * This function attempts to coerce inputs into consistent types. + * + * This is package visible for unit testing. + */ + private[analysis] def convert(table: UnresolvedInlineTable): LocalRelation = { + // For each column, traverse all the values and find a common data type and nullability. + val fields = table.rows.transpose.zip(table.names).map { case (column, name) => + val inputTypes = column.map(_.dataType) + val tpe = TypeCoercion.findWiderTypeWithoutStringPromotion(inputTypes).getOrElse { + table.failAnalysis(s"incompatible types found in column $name for inline table") + } + StructField(name, tpe, nullable = column.exists(_.nullable)) + } + val attributes = StructType(fields).toAttributes + assert(fields.size == table.names.size) + + val newRows: Seq[InternalRow] = table.rows.map { row => + InternalRow.fromSeq(row.zipWithIndex.map { case (e, ci) => + val targetType = fields(ci).dataType + try { + if (e.dataType.sameType(targetType)) { + e.eval() + } else { + Cast(e, targetType).eval() + } + } catch { + case NonFatal(ex) => + table.failAnalysis(s"failed to evaluate expression ${e.sql}: ${ex.getMessage}") + } + }) + } + + LocalRelation(attributes, newRows) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 21e96aaf53..193c3ec4e5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -150,7 +150,7 @@ object TypeCoercion { * [[findTightestCommonType]], but can handle decimal types. If the wider decimal type exceeds * system limitation, this rule will truncate the decimal type before return it. */ - private def findWiderTypeWithoutStringPromotion(types: Seq[DataType]): Option[DataType] = { + def findWiderTypeWithoutStringPromotion(types: Seq[DataType]): Option[DataType] = { types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match { case Some(d) => findTightestCommonTypeOfTwo(d, c).orElse((d, c) match { case (t1: DecimalType, t2: DecimalType) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 3735a1501c..235ae04782 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -50,10 +50,30 @@ case class UnresolvedRelation( } /** - * Holds a table-valued function call that has yet to be resolved. + * An inline table that has not been resolved yet. Once resolved, it is turned by the analyzer into + * a [[org.apache.spark.sql.catalyst.plans.logical.LocalRelation]]. + * + * @param names list of column names + * @param rows expressions for the data + */ +case class UnresolvedInlineTable( + names: Seq[String], + rows: Seq[Seq[Expression]]) + extends LeafNode { + + lazy val expressionsResolved: Boolean = rows.forall(_.forall(_.resolved)) + override lazy val resolved = false + override def output: Seq[Attribute] = Nil +} + +/** + * A table-valued function, e.g. + * {{{ + * select * from range(10); + * }}} */ -case class UnresolvedTableValuedFunction( - functionName: String, functionArgs: Seq[Expression]) extends LeafNode { +case class UnresolvedTableValuedFunction(functionName: String, functionArgs: Seq[Expression]) + extends LeafNode { override def output: Seq[Attribute] = Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 01322ae327..283e4d43ba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -670,39 +670,24 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { */ override def visitInlineTable(ctx: InlineTableContext): LogicalPlan = withOrigin(ctx) { // Get the backing expressions. - val expressions = ctx.expression.asScala.map { eCtx => - val e = expression(eCtx) - validate(e.foldable, "All expressions in an inline table must be constants.", eCtx) - e - } - - // Validate and evaluate the rows. - val (structType, structConstructor) = expressions.head.dataType match { - case st: StructType => - (st, (e: Expression) => e) - case dt => - val st = CreateStruct(Seq(expressions.head)).dataType - (st, (e: Expression) => CreateStruct(Seq(e))) - } - val rows = expressions.map { - case expression => - val safe = Cast(structConstructor(expression), structType) - safe.eval().asInstanceOf[InternalRow] + val rows = ctx.expression.asScala.map { e => + expression(e) match { + // inline table comes in two styles: + // style 1: values (1), (2), (3) -- multiple columns are supported + // style 2: values 1, 2, 3 -- only a single column is supported here + case CreateStruct(children) => children // style 1 + case child => Seq(child) // style 2 + } } - // Construct attributes. - val baseAttributes = structType.toAttributes.map(_.withNullability(true)) - val attributes = if (ctx.identifierList != null) { - val aliases = visitIdentifierList(ctx.identifierList) - validate(aliases.size == baseAttributes.size, - "Number of aliases must match the number of fields in an inline table.", ctx) - baseAttributes.zip(aliases).map(p => p._1.withName(p._2)) + val aliases = if (ctx.identifierList != null) { + visitIdentifierList(ctx.identifierList) } else { - baseAttributes + Seq.tabulate(rows.head.size)(i => s"col${i + 1}") } - // Create plan and add an alias if a name has been defined. - LocalRelation(attributes, rows).optionalMap(ctx.identifier)(aliasPlan) + val table = UnresolvedInlineTable(aliases, rows) + table.optionalMap(ctx.identifier)(aliasPlan) } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala new file mode 100644 index 0000000000..920c6ea50f --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.scalatest.BeforeAndAfter + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions.{Literal, Rand} +import org.apache.spark.sql.catalyst.expressions.aggregate.Count +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.types.{LongType, NullType} + +/** + * Unit tests for [[ResolveInlineTables]]. Note that there are also test cases defined in + * end-to-end tests (in sql/core module) for verifying the correct error messages are shown + * in negative cases. + */ +class ResolveInlineTablesSuite extends PlanTest with BeforeAndAfter { + + private def lit(v: Any): Literal = Literal(v) + + test("validate inputs are foldable") { + ResolveInlineTables.validateInputEvaluable( + UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(lit(1))))) + + // nondeterministic (rand) should not work + intercept[AnalysisException] { + ResolveInlineTables.validateInputEvaluable( + UnresolvedInlineTable(Seq("c1"), Seq(Seq(Rand(1))))) + } + + // aggregate should not work + intercept[AnalysisException] { + ResolveInlineTables.validateInputEvaluable( + UnresolvedInlineTable(Seq("c1"), Seq(Seq(Count(lit(1)))))) + } + + // unresolved attribute should not work + intercept[AnalysisException] { + ResolveInlineTables.validateInputEvaluable( + UnresolvedInlineTable(Seq("c1"), Seq(Seq(UnresolvedAttribute("A"))))) + } + } + + test("validate input dimensions") { + ResolveInlineTables.validateInputDimension( + UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2))))) + + // num alias != data dimension + intercept[AnalysisException] { + ResolveInlineTables.validateInputDimension( + UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(lit(1)), Seq(lit(2))))) + } + + // num alias == data dimension, but data themselves are inconsistent + intercept[AnalysisException] { + ResolveInlineTables.validateInputDimension( + UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(21), lit(22))))) + } + } + + test("do not fire the rule if not all expressions are resolved") { + val table = UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(UnresolvedAttribute("A")))) + assert(ResolveInlineTables(table) == table) + } + + test("convert") { + val table = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2L)))) + val converted = ResolveInlineTables.convert(table) + + assert(converted.output.map(_.dataType) == Seq(LongType)) + assert(converted.data.size == 2) + assert(converted.data(0).getLong(0) == 1L) + assert(converted.data(1).getLong(0) == 2L) + } + + test("nullability inference in convert") { + val table1 = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2L)))) + val converted1 = ResolveInlineTables.convert(table1) + assert(!converted1.schema.fields(0).nullable) + + val table2 = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(Literal(null, NullType)))) + val converted2 = ResolveInlineTables.convert(table2) + assert(converted2.schema.fields(0).nullable) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index cbe4a022e7..2fcbfc7067 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -17,9 +17,8 @@ package org.apache.spark.sql.catalyst.parser -import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.FunctionIdentifier -import org.apache.spark.sql.catalyst.analysis.{UnresolvedGenerator, UnresolvedTableValuedFunction} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedGenerator, UnresolvedInlineTable, UnresolvedTableValuedFunction} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -433,19 +432,14 @@ class PlanParserSuite extends PlanTest { } test("inline table") { - assertEqual("values 1, 2, 3, 4", LocalRelation.fromExternalRows( - Seq('col1.int), - Seq(1, 2, 3, 4).map(x => Row(x)))) + assertEqual("values 1, 2, 3, 4", + UnresolvedInlineTable(Seq("col1"), Seq(1, 2, 3, 4).map(x => Seq(Literal(x))))) + assertEqual( - "values (1, 'a'), (2, 'b'), (3, 'c') as tbl(a, b)", - LocalRelation.fromExternalRows( - Seq('a.int, 'b.string), - Seq((1, "a"), (2, "b"), (3, "c")).map(x => Row(x._1, x._2))).as("tbl")) - intercept("values (a, 'a'), (b, 'b')", - "All expressions in an inline table must be constants.") - intercept("values (1, 'a'), (2, 'b') as tbl(a, b, c)", - "Number of aliases must match the number of fields in an inline table.") - intercept[ArrayIndexOutOfBoundsException](parsePlan("values (1, 'a'), (2, 'b', 5Y)")) + "values (1, 'a'), (2, 'b') as tbl(a, b)", + UnresolvedInlineTable( + Seq("a", "b"), + Seq(Literal(1), Literal("a")) :: Seq(Literal(2), Literal("b")) :: Nil).as("tbl")) } test("simple select query with !> and !<") { -- cgit v1.2.3