about summary refs log tree commit diff
path: root/sql
diff options
context:
space:
mode:
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4                        1
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala                   1
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala  132
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala                 11
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala                   8
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala              8
-rw-r--r--  sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql                             20
-rw-r--r--  sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out                        87
8 files changed, 267 insertions, 1 deletions
diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
index 6122bcdef8..cab7c3ff5a 100644
--- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
+++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
@@ -433,6 +433,7 @@ relationPrimary
| '(' queryNoWith ')' sample? (AS? strictIdentifier)? #aliasedQuery
| '(' relation ')' sample? (AS? strictIdentifier)? #aliasedRelation
| inlineTable #inlineTableDefault2
+ | identifier '(' (expression (',' expression)*)? ')' #tableValuedFunction
;
inlineTable
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index cfab6ae7bd..333dd4d9a4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -86,6 +86,7 @@ class Analyzer(
EliminateUnions,
new SubstituteUnresolvedOrdinals(conf)),
Batch("Resolution", fixedPoint,
+ ResolveTableValuedFunctions ::
ResolveRelations ::
ResolveReferences ::
ResolveDeserializer ::
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala
new file mode 100644
index 0000000000..7fdf7fa0c0
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Range}
+import org.apache.spark.sql.catalyst.rules._
+import org.apache.spark.sql.types.{DataType, IntegerType, LongType}
+
+/**
+ * Rule that resolves table-valued function references.
+ */
+object ResolveTableValuedFunctions extends Rule[LogicalPlan] {
+ private lazy val defaultParallelism =
+ SparkContext.getOrCreate(new SparkConf(false)).defaultParallelism
+
+ /**
+ * List of argument names and their types, used to declare a function.
+ */
+ private case class ArgumentList(args: (String, DataType)*) {
+ /**
+ * Try to cast the expressions to satisfy the expected types of this argument list. If there
+ * are any types that cannot be casted, then None is returned.
+ */
+ def implicitCast(values: Seq[Expression]): Option[Seq[Expression]] = {
+ if (args.length == values.length) {
+ val casted = values.zip(args).map { case (value, (_, expectedType)) =>
+ TypeCoercion.ImplicitTypeCasts.implicitCast(value, expectedType)
+ }
+ if (casted.forall(_.isDefined)) {
+ return Some(casted.map(_.get))
+ }
+ }
+ None
+ }
+
+ override def toString: String = {
+ args.map { a =>
+ s"${a._1}: ${a._2.typeName}"
+ }.mkString(", ")
+ }
+ }
+
+ /**
+ * A TVF maps argument lists to resolver functions that accept those arguments. Using a map
+ * here allows for function overloading.
+ */
+ private type TVF = Map[ArgumentList, Seq[Any] => LogicalPlan]
+
+ /**
+ * TVF builder.
+ */
+ private def tvf(args: (String, DataType)*)(pf: PartialFunction[Seq[Any], LogicalPlan])
+ : (ArgumentList, Seq[Any] => LogicalPlan) = {
+ (ArgumentList(args: _*),
+ pf orElse {
+ case args =>
+ throw new IllegalArgumentException(
+ "Invalid arguments for resolved function: " + args.mkString(", "))
+ })
+ }
+
+ /**
+ * Internal registry of table-valued functions.
+ */
+ private val builtinFunctions: Map[String, TVF] = Map(
+ "range" -> Map(
+ /* range(end) */
+ tvf("end" -> LongType) { case Seq(end: Long) =>
+ Range(0, end, 1, defaultParallelism)
+ },
+
+ /* range(start, end) */
+ tvf("start" -> LongType, "end" -> LongType) { case Seq(start: Long, end: Long) =>
+ Range(start, end, 1, defaultParallelism)
+ },
+
+ /* range(start, end, step) */
+ tvf("start" -> LongType, "end" -> LongType, "step" -> LongType) {
+ case Seq(start: Long, end: Long, step: Long) =>
+ Range(start, end, step, defaultParallelism)
+ },
+
+ /* range(start, end, step, numPartitions) */
+ tvf("start" -> LongType, "end" -> LongType, "step" -> LongType,
+ "numPartitions" -> IntegerType) {
+ case Seq(start: Long, end: Long, step: Long, numPartitions: Int) =>
+ Range(start, end, step, numPartitions)
+ })
+ )
+
+ override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+ case u: UnresolvedTableValuedFunction if u.functionArgs.forall(_.resolved) =>
+ builtinFunctions.get(u.functionName) match {
+ case Some(tvf) =>
+ val resolved = tvf.flatMap { case (argList, resolver) =>
+ argList.implicitCast(u.functionArgs) match {
+ case Some(casted) =>
+ Some(resolver(casted.map(_.eval())))
+ case _ =>
+ None
+ }
+ }
+ resolved.headOption.getOrElse {
+ val argTypes = u.functionArgs.map(_.dataType.typeName).mkString(", ")
+ u.failAnalysis(
+ s"""error: table-valued function ${u.functionName} with alternatives:
+ |${tvf.keys.map(_.toString).toSeq.sorted.map(x => s" ($x)").mkString("\n")}
+ |cannot be applied to: (${argTypes})""".stripMargin)
+ }
+ case _ =>
+ u.failAnalysis(s"could not resolve `${u.functionName}` to a table-valued function")
+ }
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index 42e7aae0b6..3735a1501c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -50,6 +50,17 @@ case class UnresolvedRelation(
}
/**
+ * Holds a table-valued function call that has yet to be resolved.
+ */
+case class UnresolvedTableValuedFunction(
+ functionName: String, functionArgs: Seq[Expression]) extends LeafNode {
+
+ override def output: Seq[Attribute] = Nil
+
+ override lazy val resolved = false
+}
+
+/**
* Holds the name of an attribute that has yet to be resolved.
*/
case class UnresolvedAttribute(nameParts: Seq[String]) extends Attribute with Unevaluable {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index adf78396d7..01322ae327 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -658,6 +658,14 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
}
/**
+ * Create a table-valued function call with arguments, e.g. range(1000)
+ */
+ override def visitTableValuedFunction(ctx: TableValuedFunctionContext)
+ : LogicalPlan = withOrigin(ctx) {
+ UnresolvedTableValuedFunction(ctx.identifier.getText, ctx.expression.asScala.map(expression))
+ }
+
+ /**
* Create an inline table (a virtual table in Hive parlance).
*/
override def visitInlineTable(ctx: InlineTableContext): LogicalPlan = withOrigin(ctx) {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
index 7af333b34f..cbe4a022e7 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.parser
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.FunctionIdentifier
-import org.apache.spark.sql.catalyst.analysis.UnresolvedGenerator
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedGenerator, UnresolvedTableValuedFunction}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
@@ -426,6 +426,12 @@ class PlanParserSuite extends PlanTest {
assertEqual("table d.t", table("d", "t"))
}
+ test("table valued function") {
+ assertEqual(
+ "select * from range(2)",
+ UnresolvedTableValuedFunction("range", Literal(2) :: Nil).select(star()))
+ }
+
test("inline table") {
assertEqual("values 1, 2, 3, 4", LocalRelation.fromExternalRows(
Seq('col1.int),
diff --git a/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql
new file mode 100644
index 0000000000..2e6dcd538b
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql
@@ -0,0 +1,20 @@
+-- unresolved function
+select * from dummy(3);
+
+-- range call with end
+select * from range(6 + cos(3));
+
+-- range call with start and end
+select * from range(5, 10);
+
+-- range call with step
+select * from range(0, 10, 2);
+
+-- range call with numPartitions
+select * from range(0, 10, 1, 200);
+
+-- range call error
+select * from range(1, 1, 1, 1, 1);
+
+-- range call with null
+select * from range(1, null);
diff --git a/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out
new file mode 100644
index 0000000000..d769bcef0a
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out
@@ -0,0 +1,87 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 7
+
+
+-- !query 0
+select * from dummy(3)
+-- !query 0 schema
+struct<>
+-- !query 0 output
+org.apache.spark.sql.AnalysisException
+could not resolve `dummy` to a table-valued function; line 1 pos 14
+
+
+-- !query 1
+select * from range(6 + cos(3))
+-- !query 1 schema
+struct<id:bigint>
+-- !query 1 output
+0
+1
+2
+3
+4
+
+
+-- !query 2
+select * from range(5, 10)
+-- !query 2 schema
+struct<id:bigint>
+-- !query 2 output
+5
+6
+7
+8
+9
+
+
+-- !query 3
+select * from range(0, 10, 2)
+-- !query 3 schema
+struct<id:bigint>
+-- !query 3 output
+0
+2
+4
+6
+8
+
+
+-- !query 4
+select * from range(0, 10, 1, 200)
+-- !query 4 schema
+struct<id:bigint>
+-- !query 4 output
+0
+1
+2
+3
+4
+5
+6
+7
+8
+9
+
+
+-- !query 5
+select * from range(1, 1, 1, 1, 1)
+-- !query 5 schema
+struct<>
+-- !query 5 output
+org.apache.spark.sql.AnalysisException
+error: table-valued function range with alternatives:
+ (end: long)
+ (start: long, end: long)
+ (start: long, end: long, step: long)
+ (start: long, end: long, step: long, numPartitions: integer)
+cannot be applied to: (integer, integer, integer, integer, integer); line 1 pos 14
+
+
+-- !query 6
+select * from range(1, null)
+-- !query 6 schema
+struct<>
+-- !query 6 output
+java.lang.IllegalArgumentException
+Invalid arguments for resolved function: 1, null