From 1842cdd4ee9f30b0a5f579e26ff5194e81e3634c Mon Sep 17 00:00:00 2001
From: Sandeep Singh
Date: Mon, 13 Jun 2016 21:58:52 -0700
Subject: [SPARK-15663][SQL] SparkSession.catalog.listFunctions shouldn't include the list of built-in functions

## What changes were proposed in this pull request?
SparkSession.catalog.listFunctions currently returns all functions, including the built-in ones. This makes the method less useful, because any time it is run the result set contains over 100 built-in functions.

## How was this patch tested?
CatalogSuite

Author: Sandeep Singh

Closes #13413 from techaddict/SPARK-15663.
---
 python/pyspark/sql/tests.py                        | 12 +-----
 .../sql/catalyst/analysis/FunctionRegistry.scala   |  2 +
 .../sql/catalyst/catalog/SessionCatalog.scala      |  3 +-
 .../scala/org/apache/spark/sql/SQLQuerySuite.scala | 31 +++++++++++++-
 .../apache/spark/sql/internal/CatalogSuite.scala   |  6 +--
 .../spark/sql/hive/execution/SQLQuerySuite.scala   | 48 ++++++++++++++--------
 6 files changed, 67 insertions(+), 35 deletions(-)

diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 0d9dd5ea2a..e0acde6783 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -1481,17 +1481,7 @@ class SQLTests(ReusedPySparkTestCase):
         spark.sql("CREATE DATABASE some_db")
         functions = dict((f.name, f) for f in spark.catalog.listFunctions())
         functionsDefault = dict((f.name, f) for f in spark.catalog.listFunctions("default"))
-        self.assertTrue(len(functions) > 200)
-        self.assertTrue("+" in functions)
-        self.assertTrue("like" in functions)
-        self.assertTrue("month" in functions)
-        self.assertTrue("to_unix_timestamp" in functions)
-        self.assertTrue("current_database" in functions)
-        self.assertEquals(functions["+"], Function(
-            name="+",
-            description=None,
-            className="org.apache.spark.sql.catalyst.expressions.Add",
-            isTemporary=True))
+        self.assertEquals(len(functions), 0)
         self.assertEquals(functions, functionsDefault)
         spark.catalog.registerFunction("temp_func", lambda x: str(x))
         spark.sql("CREATE FUNCTION func1 AS 'org.apache.spark.data.bricks'")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index a7388c71de..42a8faa412 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -396,6 +396,8 @@ object FunctionRegistry {
     fr
   }
 
+  val functionSet: Set[String] = builtin.listFunction().toSet
+
   /** See usage above.
    */
   private def expression[T <: Expression](name: String)
       (implicit tag: ClassTag[T]): (String, (ExpressionInfo, FunctionBuilder)) = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
index 943d1071e2..1ec1bb1baf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
@@ -855,7 +855,8 @@ class SessionCatalog(
       .map { f => FunctionIdentifier(f, Some(dbName)) }
     val loadedFunctions = StringUtils.filterPattern(functionRegistry.listFunction(), pattern)
       .map { f => FunctionIdentifier(f) }
-    dbFunctions ++ loadedFunctions
+    (dbFunctions ++ loadedFunctions)
+      .filterNot(f => FunctionRegistry.functionSet.contains(f.funcName))
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 89f8685099..545c1776b7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -21,8 +21,10 @@ import java.math.MathContext
 import java.sql.Timestamp
 
 import org.apache.spark.AccumulatorSuite
+import org.apache.spark.sql.catalyst.FunctionIdentifier
 import org.apache.spark.sql.catalyst.analysis.UnresolvedException
-import org.apache.spark.sql.catalyst.expressions.SortOrder
+import org.apache.spark.sql.catalyst.catalog.{CatalogTestUtils, ExternalCatalog, SessionCatalog}
+import org.apache.spark.sql.catalyst.expressions.{ExpressionDescription, SortOrder}
 import org.apache.spark.sql.catalyst.plans.logical.Aggregate
 import org.apache.spark.sql.catalyst.util.StringUtils
 import org.apache.spark.sql.execution.aggregate
@@ -58,15 +60,40 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
 
   test("show functions") {
     def getFunctions(pattern: String): Seq[Row] = {
-      StringUtils.filterPattern(spark.sessionState.functionRegistry.listFunction(), pattern)
+      StringUtils.filterPattern(
+        spark.sessionState.catalog.listFunctions("default").map(_.funcName), pattern)
         .map(Row(_))
     }
+
+    def createFunction(names: Seq[String]): Unit = {
+      names.foreach { name =>
+        spark.udf.register(name, (arg1: Int, arg2: String) => arg2 + arg1)
+      }
+    }
+
+    def dropFunction(names: Seq[String]): Unit = {
+      names.foreach { name =>
+        spark.sessionState.catalog.dropTempFunction(name, false)
+      }
+    }
+
+    val functions = Array("ilog", "logi", "logii", "logiii", "crc32i", "cubei", "cume_disti",
+      "isize", "ispace", "to_datei", "date_addi", "current_datei")
+
+    assert(sql("SHOW functions").collect().isEmpty)
+
+    createFunction(functions)
+    checkAnswer(sql("SHOW functions"), getFunctions("*"))
+    assert(sql("SHOW functions").collect().size === functions.size)
+    assert(sql("SHOW functions").collect().toSet === functions.map(Row(_)).toSet)
+
     Seq("^c*", "*e$", "log*", "*date*").foreach { pattern =>
       // For the pattern part, only '*' and '|' are allowed as wildcards.
       // For '*', we need to replace it with '.*'.
checkAnswer(sql(s"SHOW FUNCTIONS '$pattern'"), getFunctions(pattern)) } + dropFunction(functions) } test("describe functions") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala index aec0312c40..df817f863d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala @@ -174,8 +174,7 @@ class CatalogSuite } test("list functions") { - assert(Set("+", "current_database", "window").subsetOf( - spark.catalog.listFunctions().collect().map(_.name).toSet)) + assert(spark.catalog.listFunctions().collect().isEmpty) createFunction("my_func1") createFunction("my_func2") createTempFunction("my_temp_func") @@ -192,8 +191,7 @@ class CatalogSuite } test("list functions with database") { - assert(Set("+", "current_database", "window").subsetOf( - spark.catalog.listFunctions("default").collect().map(_.name).toSet)) + assert(spark.catalog.listFunctions("default").collect().isEmpty) createDatabase("my_db1") createDatabase("my_db2") createFunction("my_func1", Some("my_db1")) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 8244ff4ce0..1a0eaa66c1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -187,28 +187,42 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } test("show functions") { - val allBuiltinFunctions = FunctionRegistry.builtin.listFunction().toSet[String].toList.sorted - // The TestContext is shared by all the test cases, some functions may be registered before - // this, so we check that all the builtin functions are returned. - val allFunctions = sql("SHOW functions").collect().map(r => r(0)) - allBuiltinFunctions.foreach { f => - assert(allFunctions.contains(f)) - } withTempDatabase { db => - checkAnswer(sql("SHOW functions abs"), Row("abs")) - checkAnswer(sql("SHOW functions 'abs'"), Row("abs")) - checkAnswer(sql(s"SHOW functions $db.abs"), Row("abs")) - checkAnswer(sql(s"SHOW functions `$db`.`abs`"), Row("abs")) - checkAnswer(sql(s"SHOW functions `$db`.`abs`"), Row("abs")) - checkAnswer(sql("SHOW functions `~`"), Row("~")) + def createFunction(names: Seq[String]): Unit = { + names.foreach { name => + sql( + s""" + |CREATE TEMPORARY FUNCTION $name + |AS '${classOf[PairUDF].getName}' + """.stripMargin) + } + } + def dropFunction(names: Seq[String]): Unit = { + names.foreach { name => + sql(s"DROP TEMPORARY FUNCTION $name") + } + } + createFunction(Seq("temp_abs", "temp_weekofyear", "temp_sha", "temp_sha1", "temp_sha2")) + + checkAnswer(sql("SHOW functions temp_abs"), Row("temp_abs")) + checkAnswer(sql("SHOW functions 'temp_abs'"), Row("temp_abs")) + checkAnswer(sql(s"SHOW functions $db.temp_abs"), Row("temp_abs")) + checkAnswer(sql(s"SHOW functions `$db`.`temp_abs`"), Row("temp_abs")) + checkAnswer(sql(s"SHOW functions `$db`.`temp_abs`"), Row("temp_abs")) checkAnswer(sql("SHOW functions `a function doens't exist`"), Nil) - checkAnswer(sql("SHOW functions `weekofyea*`"), Row("weekofyear")) + checkAnswer(sql("SHOW functions `temp_weekofyea*`"), Row("temp_weekofyear")) + // this probably will failed if we add more function with `sha` prefixing. 
- checkAnswer(sql("SHOW functions `sha*`"), Row("sha") :: Row("sha1") :: Row("sha2") :: Nil) + checkAnswer( + sql("SHOW functions `temp_sha*`"), + List(Row("temp_sha"), Row("temp_sha1"), Row("temp_sha2"))) + // Test '|' for alternation. checkAnswer( - sql("SHOW functions 'sha*|weekofyea*'"), - Row("sha") :: Row("sha1") :: Row("sha2") :: Row("weekofyear") :: Nil) + sql("SHOW functions 'temp_sha*|temp_weekofyea*'"), + List(Row("temp_sha"), Row("temp_sha1"), Row("temp_sha2"), Row("temp_weekofyear"))) + + dropFunction(Seq("temp_abs", "temp_weekofyear", "temp_sha", "temp_sha1", "temp_sha2")) } } -- cgit v1.2.3