From 412dba63b511474a6db3c43c8618d803e604bc6b Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Thu, 18 Aug 2016 13:33:55 +0200 Subject: [SPARK-17069] Expose spark.range() as table-valued function in SQL ## What changes were proposed in this pull request? This adds analyzer rules for resolving table-valued functions, and adds one builtin implementation for range(). The arguments for range() are the same as those of `spark.range()`. ## How was this patch tested? Unit tests. cc hvanhovell Author: Eric Liang Closes #14656 from ericl/sc-4309. --- .../apache/spark/sql/catalyst/parser/SqlBase.g4 | 1 + .../spark/sql/catalyst/analysis/Analyzer.scala | 1 + .../analysis/ResolveTableValuedFunctions.scala | 132 +++++++++++++++++++++ .../spark/sql/catalyst/analysis/unresolved.scala | 11 ++ .../spark/sql/catalyst/parser/AstBuilder.scala | 8 ++ 5 files changed, 153 insertions(+) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala (limited to 'sql/catalyst/src/main') diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 6122bcdef8..cab7c3ff5a 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -433,6 +433,7 @@ relationPrimary | '(' queryNoWith ')' sample? (AS? strictIdentifier)? #aliasedQuery | '(' relation ')' sample? (AS? strictIdentifier)? #aliasedRelation | inlineTable #inlineTableDefault2 + | identifier '(' (expression (',' expression)*)? ')' #tableValuedFunction ; inlineTable diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index cfab6ae7bd..333dd4d9a4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -86,6 +86,7 @@ class Analyzer( EliminateUnions, new SubstituteUnresolvedOrdinals(conf)), Batch("Resolution", fixedPoint, + ResolveTableValuedFunctions :: ResolveRelations :: ResolveReferences :: ResolveDeserializer :: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala new file mode 100644 index 0000000000..7fdf7fa0c0 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Range} +import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.types.{DataType, IntegerType, LongType} + +/** + * Rule that resolves table-valued function references. + */ +object ResolveTableValuedFunctions extends Rule[LogicalPlan] { + private lazy val defaultParallelism = + SparkContext.getOrCreate(new SparkConf(false)).defaultParallelism + + /** + * List of argument names and their types, used to declare a function. + */ + private case class ArgumentList(args: (String, DataType)*) { + /** + * Try to cast the expressions to satisfy the expected types of this argument list. If there + * are any types that cannot be casted, then None is returned. + */ + def implicitCast(values: Seq[Expression]): Option[Seq[Expression]] = { + if (args.length == values.length) { + val casted = values.zip(args).map { case (value, (_, expectedType)) => + TypeCoercion.ImplicitTypeCasts.implicitCast(value, expectedType) + } + if (casted.forall(_.isDefined)) { + return Some(casted.map(_.get)) + } + } + None + } + + override def toString: String = { + args.map { a => + s"${a._1}: ${a._2.typeName}" + }.mkString(", ") + } + } + + /** + * A TVF maps argument lists to resolver functions that accept those arguments. Using a map + * here allows for function overloading. + */ + private type TVF = Map[ArgumentList, Seq[Any] => LogicalPlan] + + /** + * TVF builder. + */ + private def tvf(args: (String, DataType)*)(pf: PartialFunction[Seq[Any], LogicalPlan]) + : (ArgumentList, Seq[Any] => LogicalPlan) = { + (ArgumentList(args: _*), + pf orElse { + case args => + throw new IllegalArgumentException( + "Invalid arguments for resolved function: " + args.mkString(", ")) + }) + } + + /** + * Internal registry of table-valued functions. + */ + private val builtinFunctions: Map[String, TVF] = Map( + "range" -> Map( + /* range(end) */ + tvf("end" -> LongType) { case Seq(end: Long) => + Range(0, end, 1, defaultParallelism) + }, + + /* range(start, end) */ + tvf("start" -> LongType, "end" -> LongType) { case Seq(start: Long, end: Long) => + Range(start, end, 1, defaultParallelism) + }, + + /* range(start, end, step) */ + tvf("start" -> LongType, "end" -> LongType, "step" -> LongType) { + case Seq(start: Long, end: Long, step: Long) => + Range(start, end, step, defaultParallelism) + }, + + /* range(start, end, step, numPartitions) */ + tvf("start" -> LongType, "end" -> LongType, "step" -> LongType, + "numPartitions" -> IntegerType) { + case Seq(start: Long, end: Long, step: Long, numPartitions: Int) => + Range(start, end, step, numPartitions) + }) + ) + + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case u: UnresolvedTableValuedFunction if u.functionArgs.forall(_.resolved) => + builtinFunctions.get(u.functionName) match { + case Some(tvf) => + val resolved = tvf.flatMap { case (argList, resolver) => + argList.implicitCast(u.functionArgs) match { + case Some(casted) => + Some(resolver(casted.map(_.eval()))) + case _ => + None + } + } + resolved.headOption.getOrElse { + val argTypes = u.functionArgs.map(_.dataType.typeName).mkString(", ") + u.failAnalysis( + s"""error: table-valued function ${u.functionName} with alternatives: + |${tvf.keys.map(_.toString).toSeq.sorted.map(x => s" ($x)").mkString("\n")} + |cannot be applied to: (${argTypes})""".stripMargin) + } + case _ => + u.failAnalysis(s"could not resolve `${u.functionName}` to a table-valued function") + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 42e7aae0b6..3735a1501c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -49,6 +49,17 @@ case class UnresolvedRelation( override lazy val resolved = false } +/** + * Holds a table-valued function call that has yet to be resolved. + */ +case class UnresolvedTableValuedFunction( + functionName: String, functionArgs: Seq[Expression]) extends LeafNode { + + override def output: Seq[Attribute] = Nil + + override lazy val resolved = false +} + /** * Holds the name of an attribute that has yet to be resolved. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index adf78396d7..01322ae327 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -657,6 +657,14 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { table.optionalMap(ctx.sample)(withSample) } + /** + * Create a table-valued function call with arguments, e.g. range(1000) + */ + override def visitTableValuedFunction(ctx: TableValuedFunctionContext) + : LogicalPlan = withOrigin(ctx) { + UnresolvedTableValuedFunction(ctx.identifier.getText, ctx.expression.asScala.map(expression)) + } + /** * Create an inline table (a virtual table in Hive parlance). */ -- cgit v1.2.3