From f36c9c83798877256efa1447a6b9be5aa47a7e87 Mon Sep 17 00:00:00 2001
From: Reynold Xin
Date: Mon, 25 Apr 2016 16:20:57 -0700
Subject: [SPARK-14888][SQL] UnresolvedFunction should use FunctionIdentifier

## What changes were proposed in this pull request?
This patch changes UnresolvedFunction and UnresolvedGenerator to use a
FunctionIdentifier rather than just a String for the function name. It also
changes SessionCatalog.lookupFunction to accept a FunctionIdentifier.

## How was this patch tested?
Updated related unit tests.

Author: Reynold Xin

Closes #12659 from rxin/SPARK-14888.
---
 .../spark/sql/catalyst/analysis/Analyzer.scala     |  4 +-
 .../sql/catalyst/analysis/FunctionRegistry.scala   |  2 +-
 .../spark/sql/catalyst/analysis/unresolved.scala   | 22 +++++--
 .../sql/catalyst/catalog/SessionCatalog.scala      | 76 +++++++++++-----------
 .../apache/spark/sql/catalyst/identifiers.scala    | 14 +++-
 .../spark/sql/catalyst/parser/AstBuilder.scala     | 19 ++++--
 .../sql/catalyst/catalog/SessionCatalogSuite.scala | 18 ++---
 .../catalyst/parser/ExpressionParserSuite.scala    |  4 +-
 .../sql/catalyst/parser/PlanParserSuite.scala      |  6 +-
 9 files changed, 100 insertions(+), 65 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 50957e8661..e37d9760cc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -838,9 +838,9 @@ class Analyzer(
             s"its class is ${other.getClass.getCanonicalName}, which is not a generator.")
         }
       }
-      case u @ UnresolvedFunction(name, children, isDistinct) =>
+      case u @ UnresolvedFunction(funcId, children, isDistinct) =>
         withPosition(u) {
-          catalog.lookupFunction(name, children) match {
+          catalog.lookupFunction(funcId, children) match {
             // DISTINCT is not meaningful for a Max or a Min.
            case max: Max if isDistinct =>
              AggregateExpression(max, Complete, isDistinct = false)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index a44430059d..dd3577063f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -387,7 +387,7 @@ object FunctionRegistry {
   }

   /** See usage above. */
-  def expression[T <: Expression](name: String)
+  private def expression[T <: Expression](name: String)
       (implicit tag: ClassTag[T]): (String, (ExpressionInfo, FunctionBuilder)) = {

     // See if we can find a constructor that accepts Seq[Expression]
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index e83008e86e..f82b63ad96 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -18,7 +18,8 @@
 package org.apache.spark.sql.catalyst.analysis

 import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.{errors, InternalRow, TableIdentifier}
+import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIdentifier}
+import org.apache.spark.sql.catalyst.errors.TreeNodeException
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode}
 import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan}
@@ -30,8 +31,8 @@ import org.apache.spark.sql.types.{DataType, StructType}
  * Thrown when an invalid attempt is made to access a property of a tree that has yet to be fully
  * resolved.
  */
-class UnresolvedException[TreeType <: TreeNode[_]](tree: TreeType, function: String) extends
-  errors.TreeNodeException(tree, s"Invalid call to $function on unresolved object", null)
+class UnresolvedException[TreeType <: TreeNode[_]](tree: TreeType, function: String)
+  extends TreeNodeException(tree, s"Invalid call to $function on unresolved object", null)

 /**
  * Holds the name of a relation that has yet to be looked up in a catalog.
@@ -138,7 +139,8 @@ object UnresolvedAttribute {
  * the [[org.apache.spark.sql.catalyst.plans.logical.Generate]] operator.
  * The analyzer will resolve this generator.
 */
-case class UnresolvedGenerator(name: String, children: Seq[Expression]) extends Generator {
+case class UnresolvedGenerator(name: FunctionIdentifier, children: Seq[Expression])
+  extends Generator {

   override def elementTypes: Seq[(DataType, Boolean, String)] =
     throw new UnresolvedException(this, "elementTypes")
@@ -147,7 +149,7 @@ case class UnresolvedGenerator(name: String, children: Seq[Expression]) extends
   override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
   override lazy val resolved = false

-  override def prettyName: String = name
+  override def prettyName: String = name.unquotedString
   override def toString: String = s"'$name(${children.mkString(", ")})"

   override def eval(input: InternalRow = null): TraversableOnce[InternalRow] =
@@ -161,7 +163,7 @@ case class UnresolvedGenerator(name: String, children: Seq[Expression]) extends
 }

 case class UnresolvedFunction(
-    name: String,
+    name: FunctionIdentifier,
     children: Seq[Expression],
     isDistinct: Boolean)
   extends Expression with Unevaluable {
@@ -171,10 +173,16 @@ case class UnresolvedFunction(
   override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
   override lazy val resolved = false

-  override def prettyName: String = name
+  override def prettyName: String = name.unquotedString
   override def toString: String = s"'$name(${children.mkString(", ")})"
 }

+object UnresolvedFunction {
+  def apply(name: String, children: Seq[Expression], isDistinct: Boolean): UnresolvedFunction = {
+    UnresolvedFunction(FunctionIdentifier(name, None), children, isDistinct)
+  }
+}
+
 /**
  * Represents all of the input attributes to a given relational operator, for example in
  * "SELECT * FROM ...". A [[Star]] gets automatically expanded during analysis.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
index 67b1752ee8..b36a76a888 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
@@ -59,13 +59,14 @@ class SessionCatalog(
     this(externalCatalog, new SimpleFunctionRegistry, new SimpleCatalystConf(true))
   }

-  protected[this] val tempTables = new mutable.HashMap[String, LogicalPlan]
+  /** List of temporary tables, mapping from table name to their logical plan. */
+  protected val tempTables = new mutable.HashMap[String, LogicalPlan]

   // Note: we track current database here because certain operations do not explicitly
   // specify the database (e.g. DROP TABLE my_table). In these cases we must first
   // check whether the temporary table or function exists, then, if not, operate on
   // the corresponding item in the current database.
-  protected[this] var currentDb = {
+  protected var currentDb = {
     val defaultName = "default"
     val defaultDbDefinition = CatalogDatabase(defaultName, "default database", "", Map())
     // Initialize default database if it doesn't already exist
@@ -118,7 +119,7 @@ class SessionCatalog(

   def setCurrentDatabase(db: String): Unit = {
     if (!databaseExists(db)) {
-      throw new AnalysisException(s"cannot set current database to non-existent '$db'")
+      throw new AnalysisException(s"Database '$db' does not exist.")
     }
     currentDb = db
   }
@@ -593,9 +594,6 @@ class SessionCatalog(
   /**
    * Drop a temporary function.
 */
-  // TODO: The reason that we distinguish dropFunction and dropTempFunction is that
-  // Hive has DROP FUNCTION and DROP TEMPORARY FUNCTION. We may want to consolidate
-  // dropFunction and dropTempFunction.
   def dropTempFunction(name: String, ignoreIfNotExists: Boolean): Unit = {
     if (!functionRegistry.dropFunction(name) && !ignoreIfNotExists) {
       throw new AnalysisException(
@@ -622,40 +620,44 @@ class SessionCatalog(
    * based on the function class and put the builder into the FunctionRegistry.
    * The name of this function in the FunctionRegistry will be `databaseName.functionName`.
    */
-  def lookupFunction(name: String, children: Seq[Expression]): Expression = {
-    // TODO: Right now, the name can be qualified or not qualified.
-    // It will be better to get a FunctionIdentifier.
-    // TODO: Right now, we assume that name is not qualified!
-    val qualifiedName = FunctionIdentifier(name, Some(currentDb)).unquotedString
-    if (functionRegistry.functionExists(name)) {
+  def lookupFunction(name: FunctionIdentifier, children: Seq[Expression]): Expression = {
+    // Note: the implementation of this function is a little bit convoluted.
+    // We probably shouldn't use a single FunctionRegistry to register all three kinds of functions
+    // (built-in, temp, and external).
+    if (name.database.isEmpty && functionRegistry.functionExists(name.funcName)) {
       // This function has been already loaded into the function registry.
-      functionRegistry.lookupFunction(name, children)
-    } else if (functionRegistry.functionExists(qualifiedName)) {
+      return functionRegistry.lookupFunction(name.funcName, children)
+    }
+
+    // If the name itself is not qualified, add the current database to it.
+    val qualifiedName = if (name.database.isEmpty) name.copy(database = Some(currentDb)) else name
+
+    if (functionRegistry.functionExists(qualifiedName.unquotedString)) {
       // This function has been already loaded into the function registry.
       // Unlike the above block, we find this function by using the qualified name.
-      functionRegistry.lookupFunction(qualifiedName, children)
-    } else {
-      // The function has not been loaded to the function registry, which means
-      // that the function is a permanent function (if it actually has been registered
-      // in the metastore). We need to first put the function in the FunctionRegistry.
-      val catalogFunction = try {
-        externalCatalog.getFunction(currentDb, name)
-      } catch {
-        case e: AnalysisException => failFunctionLookup(name)
-        case e: NoSuchFunctionException => failFunctionLookup(name)
-      }
-      loadFunctionResources(catalogFunction.resources)
-      // Please note that qualifiedName is provided by the user. However,
-      // catalogFunction.identifier.unquotedString is returned by the underlying
-      // catalog. So, it is possible that qualifiedName is not exactly the same as
-      // catalogFunction.identifier.unquotedString (difference is on case-sensitivity).
-      // At here, we preserve the input from the user.
-      val info = new ExpressionInfo(catalogFunction.className, qualifiedName)
-      val builder = makeFunctionBuilder(qualifiedName, catalogFunction.className)
-      createTempFunction(qualifiedName, info, builder, ignoreIfExists = false)
-      // Now, we need to create the Expression.
-      functionRegistry.lookupFunction(qualifiedName, children)
+      return functionRegistry.lookupFunction(qualifiedName.unquotedString, children)
+    }
+
+    // The function has not been loaded to the function registry, which means
+    // that the function is a permanent function (if it actually has been registered
+    // in the metastore). We need to first put the function in the FunctionRegistry.
+    val catalogFunction = try {
+      externalCatalog.getFunction(currentDb, name.funcName)
+    } catch {
+      case e: AnalysisException => failFunctionLookup(name.funcName)
+      case e: NoSuchFunctionException => failFunctionLookup(name.funcName)
     }
+    loadFunctionResources(catalogFunction.resources)
+    // Please note that qualifiedName is provided by the user. However,
+    // catalogFunction.identifier.unquotedString is returned by the underlying
+    // catalog. So, it is possible that qualifiedName is not exactly the same as
+    // catalogFunction.identifier.unquotedString (difference is on case-sensitivity).
+    // At here, we preserve the input from the user.
+    val info = new ExpressionInfo(catalogFunction.className, qualifiedName.unquotedString)
+    val builder = makeFunctionBuilder(qualifiedName.unquotedString, catalogFunction.className)
+    createTempFunction(qualifiedName.unquotedString, info, builder, ignoreIfExists = false)
+    // Now, we need to create the Expression.
+    return functionRegistry.lookupFunction(qualifiedName.unquotedString, children)
   }

   /**
@@ -671,8 +673,6 @@ class SessionCatalog(
       externalCatalog.listFunctions(db, pattern).map { f => FunctionIdentifier(f, Some(db)) }
     val loadedFunctions = StringUtils.filterPattern(functionRegistry.listFunction(), pattern)
       .map { f => FunctionIdentifier(f) }
-    // TODO: Actually, there will be dbFunctions that have been loaded into the FunctionRegistry.
-    // So, the returned list may have two entries for the same function.
     dbFunctions ++ loadedFunctions
   }

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala
index aae75956ea..7d0584511f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala
@@ -26,9 +26,17 @@ package org.apache.spark.sql.catalyst
  */
 sealed trait IdentifierWithDatabase {
   val identifier: String
+
   def database: Option[String]
-  def quotedString: String = database.map(db => s"`$db`.`$identifier`").getOrElse(s"`$identifier`")
-  def unquotedString: String = database.map(db => s"$db.$identifier").getOrElse(identifier)
+
+  def quotedString: String = {
+    if (database.isDefined) s"`${database.get}`.`$identifier`" else s"`$identifier`"
+  }
+
+  def unquotedString: String = {
+    if (database.isDefined) s"${database.get}.$identifier" else identifier
+  }
+
   override def toString: String = quotedString
 }
@@ -63,6 +71,8 @@ case class FunctionIdentifier(funcName: String, database: Option[String])
   override val identifier: String = funcName

   def this(funcName: String) = this(funcName, None)
+
+  override def toString: String = unquotedString
 }

 object FunctionIdentifier {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 1c067621df..3d7b888d72 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -25,7 +25,7 @@ import org.antlr.v4.runtime.{ParserRuleContext, Token}
 import org.antlr.v4.runtime.tree.{ParseTree, RuleNode, TerminalNode}

 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier}
+import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIdentifier}
 import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.parser.SqlBaseParser._
@@ -554,7 +554,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
       case "json_tuple" =>
         JsonTuple(expressions)
       case name =>
-        UnresolvedGenerator(name, expressions)
+        UnresolvedGenerator(visitFunctionName(ctx.qualifiedName), expressions)
     }

     Generate(
@@ -1033,12 +1033,12 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
     val isDistinct = Option(ctx.setQuantifier()).exists(_.DISTINCT != null)
     val arguments = ctx.expression().asScala.map(expression) match {
       case Seq(UnresolvedStar(None)) if name.toLowerCase == "count" && !isDistinct =>
-        // Transform COUNT(*) into COUNT(1). Move this to analysis?
+        // Transform COUNT(*) into COUNT(1).
         Seq(Literal(1))
       case expressions =>
         expressions
     }

-    val function = UnresolvedFunction(name, arguments, isDistinct)
+    val function = UnresolvedFunction(visitFunctionName(ctx.qualifiedName), arguments, isDistinct)

     // Check if the function is evaluated in a windowed context.
     ctx.windowSpec match {
@@ -1050,6 +1050,17 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
     }
   }

+  /**
+   * Create a function database (optional) and name pair.
+   */
+  protected def visitFunctionName(ctx: QualifiedNameContext): FunctionIdentifier = {
+    ctx.identifier().asScala.map(_.getText) match {
+      case Seq(db, fn) => FunctionIdentifier(fn, Option(db))
+      case Seq(fn) => FunctionIdentifier(fn, None)
+      case other => throw new ParseException(s"Unsupported function name '${ctx.getText}'", ctx)
+    }
+  }
+
   /**
    * Create a reference to a window frame, i.e. [[WindowSpecReference]].
    */
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala
index 27205c4587..1933be50b2 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala
@@ -705,11 +705,11 @@ class SessionCatalogSuite extends SparkFunSuite {
     catalog.createTempFunction("temp1", info1, tempFunc1, ignoreIfExists = false)
     catalog.createTempFunction("temp2", info2, tempFunc2, ignoreIfExists = false)
     val arguments = Seq(Literal(1), Literal(2), Literal(3))
-    assert(catalog.lookupFunction("temp1", arguments) === Literal(1))
-    assert(catalog.lookupFunction("temp2", arguments) === Literal(3))
+    assert(catalog.lookupFunction(FunctionIdentifier("temp1"), arguments) === Literal(1))
+    assert(catalog.lookupFunction(FunctionIdentifier("temp2"), arguments) === Literal(3))
     // Temporary function does not exist.
     intercept[AnalysisException] {
-      catalog.lookupFunction("temp3", arguments)
+      catalog.lookupFunction(FunctionIdentifier("temp3"), arguments)
     }
     val tempFunc3 = (e: Seq[Expression]) => Literal(e.size)
     val info3 = new ExpressionInfo("tempFunc3", "temp1")
@@ -719,7 +719,8 @@ class SessionCatalogSuite extends SparkFunSuite {
     }
     // Temporary function is overridden
     catalog.createTempFunction("temp1", info3, tempFunc3, ignoreIfExists = true)
-    assert(catalog.lookupFunction("temp1", arguments) === Literal(arguments.length))
+    assert(
+      catalog.lookupFunction(FunctionIdentifier("temp1"), arguments) === Literal(arguments.length))
   }

   test("drop function") {
@@ -755,10 +756,10 @@ class SessionCatalogSuite extends SparkFunSuite {
     val tempFunc = (e: Seq[Expression]) => e.head
     catalog.createTempFunction("func1", info, tempFunc, ignoreIfExists = false)
     val arguments = Seq(Literal(1), Literal(2), Literal(3))
-    assert(catalog.lookupFunction("func1", arguments) === Literal(1))
+    assert(catalog.lookupFunction(FunctionIdentifier("func1"), arguments) === Literal(1))
     catalog.dropTempFunction("func1", ignoreIfNotExists = false)
     intercept[AnalysisException] {
-      catalog.lookupFunction("func1", arguments)
+      catalog.lookupFunction(FunctionIdentifier("func1"), arguments)
     }
     intercept[AnalysisException] {
       catalog.dropTempFunction("func1", ignoreIfNotExists = false)
@@ -792,10 +793,11 @@ class SessionCatalogSuite extends SparkFunSuite {
     val info1 = new ExpressionInfo("tempFunc1", "func1")
     val tempFunc1 = (e: Seq[Expression]) => e.head
     catalog.createTempFunction("func1", info1, tempFunc1, ignoreIfExists = false)
-    assert(catalog.lookupFunction("func1", Seq(Literal(1), Literal(2), Literal(3))) == Literal(1))
+    assert(catalog.lookupFunction(
+      FunctionIdentifier("func1"), Seq(Literal(1), Literal(2), Literal(3))) == Literal(1))
     catalog.dropTempFunction("func1", ignoreIfNotExists = false)
     intercept[AnalysisException] {
-      catalog.lookupFunction("func1", Seq(Literal(1), Literal(2), Literal(3)))
+      catalog.lookupFunction(FunctionIdentifier("func1"), Seq(Literal(1), Literal(2), Literal(3)))
     }
   }

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
index f0ddc92bcd..5af3ea9c7a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
@@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.parser

 import java.sql.{Date, Timestamp}

+import org.apache.spark.sql.catalyst.FunctionIdentifier
 import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, _}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.PlanTest
@@ -199,7 +200,8 @@ class ExpressionParserSuite extends PlanTest {

   test("function expressions") {
     assertEqual("foo()", 'foo.function())
-    assertEqual("foo.bar()", Symbol("foo.bar").function())
+    assertEqual("foo.bar()",
+      UnresolvedFunction(FunctionIdentifier("bar", Some("foo")), Seq.empty, isDistinct = false))
     assertEqual("foo(*)", 'foo.function(star()))
     assertEqual("count(*)", 'count.function(1))
     assertEqual("foo(a, b)", 'foo.function('a, 'b))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
index 56c91a0fd5..5e896a33bd 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
@@ -17,11 +17,13 @@
 package org.apache.spark.sql.catalyst.parser

 import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.FunctionIdentifier
 import org.apache.spark.sql.catalyst.analysis.UnresolvedGenerator
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.types.{BooleanType, IntegerType}
+import org.apache.spark.sql.types.IntegerType
+

 class PlanParserSuite extends PlanTest {
   import CatalystSqlParser._
@@ -300,7 +302,7 @@ class PlanParserSuite extends PlanTest {
     // Unresolved generator.
     val expected = table("t")
       .generate(
-        UnresolvedGenerator("posexplode", Seq('x)),
+        UnresolvedGenerator(FunctionIdentifier("posexplode"), Seq('x)),
        join = true,
        outer = false,
        Some("posexpl"),
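
Editor's note: the heart of this patch is the new resolution order in `SessionCatalog.lookupFunction`: first try an unqualified name against the registry (built-in or temporary functions), then retry with the name qualified by the current database, and only then fall back to the external catalog, caching the result in the registry. The sketch below is a minimal, self-contained Scala model of that order; the `FunctionIdentifier` case class is a simplified copy, and `registry`, `external`, `LookupSketch`, and the string "class names" are all hypothetical stand-ins, not the real Spark `FunctionRegistry`/`ExternalCatalog` APIs.

```scala
import scala.collection.mutable

// Simplified stand-in for org.apache.spark.sql.catalyst.FunctionIdentifier.
case class FunctionIdentifier(funcName: String, database: Option[String] = None) {
  def unquotedString: String = database.map(db => s"$db.$funcName").getOrElse(funcName)
  override def toString: String = unquotedString
}

object LookupSketch {
  private val currentDb = "default"
  // Hypothetical in-memory registry holding built-in and temporary functions.
  private val registry = mutable.Map("max" -> "builtin:Max")
  // Hypothetical external catalog of permanent functions, keyed by qualified name.
  private val external = Map("default.myfunc" -> "com.example.MyUDF")

  def lookupFunction(name: FunctionIdentifier): String = {
    // 1. An unqualified name may hit a built-in or temporary function directly.
    if (name.database.isEmpty && registry.contains(name.funcName)) {
      return registry(name.funcName)
    }
    // 2. Otherwise, qualify the name with the current database and retry.
    val qualified =
      if (name.database.isEmpty) name.copy(database = Some(currentDb)) else name
    registry.getOrElse(qualified.unquotedString, {
      // 3. Fall back to the external catalog; cache the hit under the qualified
      //    name so subsequent lookups take path 2.
      val className = external.getOrElse(qualified.unquotedString,
        sys.error(s"Undefined function: $name"))
      registry(qualified.unquotedString) = className
      className
    })
  }
}

object LookupSketchDemo extends App {
  println(LookupSketch.lookupFunction(FunctionIdentifier("max")))    // builtin:Max
  println(LookupSketch.lookupFunction(FunctionIdentifier("myfunc"))) // com.example.MyUDF
  println(LookupSketch.lookupFunction(
    FunctionIdentifier("myfunc", Some("default"))))                  // cached registry hit
}
```

Under this model, `FunctionIdentifier("foo", Some("db"))` and the bare `FunctionIdentifier("foo")` can resolve to different functions, which is precisely the distinction a plain `String` name could not express.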