about summary refs log tree commit diff
path: root/sql/catalyst/src/main
diff options
context:
space:
mode:
author    Eric Liang <ekl@databricks.com>  2016-08-18 13:33:55 +0200
committer Herman van Hovell <hvanhovell@databricks.com>  2016-08-18 13:33:55 +0200
commit    412dba63b511474a6db3c43c8618d803e604bc6b (patch)
tree      2247e6674c4a49a934b98b27fd8a2d7ff7d177dc /sql/catalyst/src/main
parent    b81421afb04959bb22b53653be0a09c1f1c5845f (diff)
download  spark-412dba63b511474a6db3c43c8618d803e604bc6b.tar.gz
          spark-412dba63b511474a6db3c43c8618d803e604bc6b.tar.bz2
          spark-412dba63b511474a6db3c43c8618d803e604bc6b.zip
[SPARK-17069] Expose spark.range() as table-valued function in SQL
## What changes were proposed in this pull request?

This adds analyzer rules for resolving table-valued functions, and adds one builtin implementation for range(). The arguments for range() are the same as those of `spark.range()`.

## How was this patch tested?

Unit tests.

cc hvanhovell

Author: Eric Liang <ekl@databricks.com>

Closes #14656 from ericl/sc-4309.
Diffstat (limited to 'sql/catalyst/src/main')
-rw-r--r--  sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4                       |   1
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala                  |   1
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala | 132
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala                |  11
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala                  |   8
5 files changed, 153 insertions(+), 0 deletions(-)
diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
index 6122bcdef8..cab7c3ff5a 100644
--- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
+++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
@@ -433,6 +433,7 @@ relationPrimary
| '(' queryNoWith ')' sample? (AS? strictIdentifier)? #aliasedQuery
| '(' relation ')' sample? (AS? strictIdentifier)? #aliasedRelation
| inlineTable #inlineTableDefault2
+ | identifier '(' (expression (',' expression)*)? ')' #tableValuedFunction
;
inlineTable
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index cfab6ae7bd..333dd4d9a4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -86,6 +86,7 @@ class Analyzer(
EliminateUnions,
new SubstituteUnresolvedOrdinals(conf)),
Batch("Resolution", fixedPoint,
+ ResolveTableValuedFunctions ::
ResolveRelations ::
ResolveReferences ::
ResolveDeserializer ::
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala
new file mode 100644
index 0000000000..7fdf7fa0c0
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Range}
+import org.apache.spark.sql.catalyst.rules._
+import org.apache.spark.sql.types.{DataType, IntegerType, LongType}
+
+/**
+ * Rule that resolves table-valued function references.
+ */
+object ResolveTableValuedFunctions extends Rule[LogicalPlan] {
+ private lazy val defaultParallelism =
+ SparkContext.getOrCreate(new SparkConf(false)).defaultParallelism
+
+ /**
+ * List of argument names and their types, used to declare a function.
+ */
+ private case class ArgumentList(args: (String, DataType)*) {
+ /**
+ * Try to cast the expressions to satisfy the expected types of this argument list. If there
+ * are any types that cannot be casted, then None is returned.
+ */
+ def implicitCast(values: Seq[Expression]): Option[Seq[Expression]] = {
+ if (args.length == values.length) {
+ val casted = values.zip(args).map { case (value, (_, expectedType)) =>
+ TypeCoercion.ImplicitTypeCasts.implicitCast(value, expectedType)
+ }
+ if (casted.forall(_.isDefined)) {
+ return Some(casted.map(_.get))
+ }
+ }
+ None
+ }
+
+ override def toString: String = {
+ args.map { a =>
+ s"${a._1}: ${a._2.typeName}"
+ }.mkString(", ")
+ }
+ }
+
+ /**
+ * A TVF maps argument lists to resolver functions that accept those arguments. Using a map
+ * here allows for function overloading.
+ */
+ private type TVF = Map[ArgumentList, Seq[Any] => LogicalPlan]
+
+ /**
+ * TVF builder.
+ */
+ private def tvf(args: (String, DataType)*)(pf: PartialFunction[Seq[Any], LogicalPlan])
+ : (ArgumentList, Seq[Any] => LogicalPlan) = {
+ (ArgumentList(args: _*),
+ pf orElse {
+ case args =>
+ throw new IllegalArgumentException(
+ "Invalid arguments for resolved function: " + args.mkString(", "))
+ })
+ }
+
+ /**
+ * Internal registry of table-valued functions.
+ */
+ private val builtinFunctions: Map[String, TVF] = Map(
+ "range" -> Map(
+ /* range(end) */
+ tvf("end" -> LongType) { case Seq(end: Long) =>
+ Range(0, end, 1, defaultParallelism)
+ },
+
+ /* range(start, end) */
+ tvf("start" -> LongType, "end" -> LongType) { case Seq(start: Long, end: Long) =>
+ Range(start, end, 1, defaultParallelism)
+ },
+
+ /* range(start, end, step) */
+ tvf("start" -> LongType, "end" -> LongType, "step" -> LongType) {
+ case Seq(start: Long, end: Long, step: Long) =>
+ Range(start, end, step, defaultParallelism)
+ },
+
+ /* range(start, end, step, numPartitions) */
+ tvf("start" -> LongType, "end" -> LongType, "step" -> LongType,
+ "numPartitions" -> IntegerType) {
+ case Seq(start: Long, end: Long, step: Long, numPartitions: Int) =>
+ Range(start, end, step, numPartitions)
+ })
+ )
+
+ override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+ case u: UnresolvedTableValuedFunction if u.functionArgs.forall(_.resolved) =>
+ builtinFunctions.get(u.functionName) match {
+ case Some(tvf) =>
+ val resolved = tvf.flatMap { case (argList, resolver) =>
+ argList.implicitCast(u.functionArgs) match {
+ case Some(casted) =>
+ Some(resolver(casted.map(_.eval())))
+ case _ =>
+ None
+ }
+ }
+ resolved.headOption.getOrElse {
+ val argTypes = u.functionArgs.map(_.dataType.typeName).mkString(", ")
+ u.failAnalysis(
+ s"""error: table-valued function ${u.functionName} with alternatives:
+ |${tvf.keys.map(_.toString).toSeq.sorted.map(x => s" ($x)").mkString("\n")}
+ |cannot be applied to: (${argTypes})""".stripMargin)
+ }
+ case _ =>
+ u.failAnalysis(s"could not resolve `${u.functionName}` to a table-valued function")
+ }
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index 42e7aae0b6..3735a1501c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -50,6 +50,17 @@ case class UnresolvedRelation(
}
/**
+ * Holds a table-valued function call that has yet to be resolved.
+ */
+case class UnresolvedTableValuedFunction(
+ functionName: String, functionArgs: Seq[Expression]) extends LeafNode {
+
+ override def output: Seq[Attribute] = Nil
+
+ override lazy val resolved = false
+}
+
+/**
* Holds the name of an attribute that has yet to be resolved.
*/
case class UnresolvedAttribute(nameParts: Seq[String]) extends Attribute with Unevaluable {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index adf78396d7..01322ae327 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -658,6 +658,14 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
}
/**
+ * Create a table-valued function call with arguments, e.g. range(1000)
+ */
+ override def visitTableValuedFunction(ctx: TableValuedFunctionContext)
+ : LogicalPlan = withOrigin(ctx) {
+ UnresolvedTableValuedFunction(ctx.identifier.getText, ctx.expression.asScala.map(expression))
+ }
+
+ /**
* Create an inline table (a virtual table in Hive parlance).
*/
override def visitInlineTable(ctx: InlineTableContext): LogicalPlan = withOrigin(ctx) {