diff options
author | Eric Liang <ekl@databricks.com> | 2016-08-18 13:33:55 +0200 |
---|---|---|
committer | Herman van Hovell <hvanhovell@databricks.com> | 2016-08-18 13:33:55 +0200 |
commit | 412dba63b511474a6db3c43c8618d803e604bc6b (patch) | |
tree | 2247e6674c4a49a934b98b27fd8a2d7ff7d177dc /sql/catalyst | |
parent | b81421afb04959bb22b53653be0a09c1f1c5845f (diff) | |
download | spark-412dba63b511474a6db3c43c8618d803e604bc6b.tar.gz spark-412dba63b511474a6db3c43c8618d803e604bc6b.tar.bz2 spark-412dba63b511474a6db3c43c8618d803e604bc6b.zip |
[SPARK-17069] Expose spark.range() as table-valued function in SQL
## What changes were proposed in this pull request?
This adds analyzer rules for resolving table-valued functions, and adds one builtin implementation for range(). The arguments for range() are the same as those of `spark.range()`.
## How was this patch tested?
Unit tests.
cc hvanhovell
Author: Eric Liang <ekl@databricks.com>
Closes #14656 from ericl/sc-4309.
Diffstat (limited to 'sql/catalyst')
6 files changed, 160 insertions, 1 deletions
diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 6122bcdef8..cab7c3ff5a 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -433,6 +433,7 @@ relationPrimary | '(' queryNoWith ')' sample? (AS? strictIdentifier)? #aliasedQuery | '(' relation ')' sample? (AS? strictIdentifier)? #aliasedRelation | inlineTable #inlineTableDefault2 + | identifier '(' (expression (',' expression)*)? ')' #tableValuedFunction ; inlineTable diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index cfab6ae7bd..333dd4d9a4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -86,6 +86,7 @@ class Analyzer( EliminateUnions, new SubstituteUnresolvedOrdinals(conf)), Batch("Resolution", fixedPoint, + ResolveTableValuedFunctions :: ResolveRelations :: ResolveReferences :: ResolveDeserializer :: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala new file mode 100644 index 0000000000..7fdf7fa0c0 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Range} +import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.types.{DataType, IntegerType, LongType} + +/** + * Rule that resolves table-valued function references. + */ +object ResolveTableValuedFunctions extends Rule[LogicalPlan] { + private lazy val defaultParallelism = + SparkContext.getOrCreate(new SparkConf(false)).defaultParallelism + + /** + * List of argument names and their types, used to declare a function. + */ + private case class ArgumentList(args: (String, DataType)*) { + /** + * Try to cast the expressions to satisfy the expected types of this argument list. If there + * are any types that cannot be casted, then None is returned. + */ + def implicitCast(values: Seq[Expression]): Option[Seq[Expression]] = { + if (args.length == values.length) { + val casted = values.zip(args).map { case (value, (_, expectedType)) => + TypeCoercion.ImplicitTypeCasts.implicitCast(value, expectedType) + } + if (casted.forall(_.isDefined)) { + return Some(casted.map(_.get)) + } + } + None + } + + override def toString: String = { + args.map { a => + s"${a._1}: ${a._2.typeName}" + }.mkString(", ") + } + } + + /** + * A TVF maps argument lists to resolver functions that accept those arguments. Using a map + * here allows for function overloading. + */ + private type TVF = Map[ArgumentList, Seq[Any] => LogicalPlan] + + /** + * TVF builder. + */ + private def tvf(args: (String, DataType)*)(pf: PartialFunction[Seq[Any], LogicalPlan]) + : (ArgumentList, Seq[Any] => LogicalPlan) = { + (ArgumentList(args: _*), + pf orElse { + case args => + throw new IllegalArgumentException( + "Invalid arguments for resolved function: " + args.mkString(", ")) + }) + } + + /** + * Internal registry of table-valued functions. + */ + private val builtinFunctions: Map[String, TVF] = Map( + "range" -> Map( + /* range(end) */ + tvf("end" -> LongType) { case Seq(end: Long) => + Range(0, end, 1, defaultParallelism) + }, + + /* range(start, end) */ + tvf("start" -> LongType, "end" -> LongType) { case Seq(start: Long, end: Long) => + Range(start, end, 1, defaultParallelism) + }, + + /* range(start, end, step) */ + tvf("start" -> LongType, "end" -> LongType, "step" -> LongType) { + case Seq(start: Long, end: Long, step: Long) => + Range(start, end, step, defaultParallelism) + }, + + /* range(start, end, step, numPartitions) */ + tvf("start" -> LongType, "end" -> LongType, "step" -> LongType, + "numPartitions" -> IntegerType) { + case Seq(start: Long, end: Long, step: Long, numPartitions: Int) => + Range(start, end, step, numPartitions) + }) + ) + + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case u: UnresolvedTableValuedFunction if u.functionArgs.forall(_.resolved) => + builtinFunctions.get(u.functionName) match { + case Some(tvf) => + val resolved = tvf.flatMap { case (argList, resolver) => + argList.implicitCast(u.functionArgs) match { + case Some(casted) => + Some(resolver(casted.map(_.eval()))) + case _ => + None + } + } + resolved.headOption.getOrElse { + val argTypes = u.functionArgs.map(_.dataType.typeName).mkString(", ") + u.failAnalysis( + s"""error: table-valued function ${u.functionName} with alternatives: + |${tvf.keys.map(_.toString).toSeq.sorted.map(x => s" ($x)").mkString("\n")} + |cannot be applied to: (${argTypes})""".stripMargin) + } + case _ => + u.failAnalysis(s"could not resolve `${u.functionName}` to a table-valued function") + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 42e7aae0b6..3735a1501c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -50,6 +50,17 @@ case class UnresolvedRelation( } /** + * Holds a table-valued function call that has yet to be resolved. + */ +case class UnresolvedTableValuedFunction( + functionName: String, functionArgs: Seq[Expression]) extends LeafNode { + + override def output: Seq[Attribute] = Nil + + override lazy val resolved = false +} + +/** * Holds the name of an attribute that has yet to be resolved. */ case class UnresolvedAttribute(nameParts: Seq[String]) extends Attribute with Unevaluable { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index adf78396d7..01322ae327 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -658,6 +658,14 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { } /** + * Create a table-valued function call with arguments, e.g. range(1000) + */ + override def visitTableValuedFunction(ctx: TableValuedFunctionContext) + : LogicalPlan = withOrigin(ctx) { + UnresolvedTableValuedFunction(ctx.identifier.getText, ctx.expression.asScala.map(expression)) + } + + /** * Create an inline table (a virtual table in Hive parlance). */ override def visitInlineTable(ctx: InlineTableContext): LogicalPlan = withOrigin(ctx) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 7af333b34f..cbe4a022e7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.parser import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.FunctionIdentifier -import org.apache.spark.sql.catalyst.analysis.UnresolvedGenerator +import org.apache.spark.sql.catalyst.analysis.{UnresolvedGenerator, UnresolvedTableValuedFunction} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -426,6 +426,12 @@ class PlanParserSuite extends PlanTest { assertEqual("table d.t", table("d", "t")) } + test("table valued function") { + assertEqual( + "select * from range(2)", + UnresolvedTableValuedFunction("range", Literal(2) :: Nil).select(star())) + } + test("inline table") { assertEqual("values 1, 2, 3, 4", LocalRelation.fromExternalRows( Seq('col1.int), |