authorSandeep Singh <sandeep@techaddict.me>2016-06-13 21:58:52 -0700
committerYin Huai <yhuai@databricks.com>2016-06-13 21:58:52 -0700
commit1842cdd4ee9f30b0a5f579e26ff5194e81e3634c (patch)
treef499dc95ea2f0765dd89b69857a47fcee8fae433
parentbaa3e633e18c47b12e79fe3ddc01fc8ec010f096 (diff)
[SPARK-15663][SQL] SparkSession.catalog.listFunctions shouldn't include the list of built-in functions
## What changes were proposed in this pull request?

SparkSession.catalog.listFunctions currently returns all functions, including the built-in ones. This makes the method much less useful, because every call returns a result set padded with more than 100 built-in functions.

## How was this patch tested?

CatalogSuite

Author: Sandeep Singh <sandeep@techaddict.me>

Closes #13413 from techaddict/SPARK-15663.
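To illustrate the new behavior, a minimal usage sketch (assuming a fresh SparkSession bound to `spark`; the UDF name `plus_one` is hypothetical):

```scala
// Built-in functions such as "abs" are no longer listed.
assert(spark.catalog.listFunctions().collect().isEmpty)

// Only user-registered functions appear in the result.
spark.udf.register("plus_one", (x: Int) => x + 1)
val names = spark.catalog.listFunctions().collect().map(_.name)
assert(names.contains("plus_one"))
assert(!names.contains("abs"))
```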
-rw-r--r--python/pyspark/sql/tests.py12
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala2
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala3
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala31
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala6
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala48
6 files changed, 67 insertions, 35 deletions
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 0d9dd5ea2a..e0acde6783 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -1481,17 +1481,7 @@ class SQLTests(ReusedPySparkTestCase):
spark.sql("CREATE DATABASE some_db")
functions = dict((f.name, f) for f in spark.catalog.listFunctions())
functionsDefault = dict((f.name, f) for f in spark.catalog.listFunctions("default"))
- self.assertTrue(len(functions) > 200)
- self.assertTrue("+" in functions)
- self.assertTrue("like" in functions)
- self.assertTrue("month" in functions)
- self.assertTrue("to_unix_timestamp" in functions)
- self.assertTrue("current_database" in functions)
- self.assertEquals(functions["+"], Function(
- name="+",
- description=None,
- className="org.apache.spark.sql.catalyst.expressions.Add",
- isTemporary=True))
+ self.assertEquals(len(functions), 0)
self.assertEquals(functions, functionsDefault)
spark.catalog.registerFunction("temp_func", lambda x: str(x))
spark.sql("CREATE FUNCTION func1 AS 'org.apache.spark.data.bricks'")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index a7388c71de..42a8faa412 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -396,6 +396,8 @@ object FunctionRegistry {
fr
}
+ val functionSet: Set[String] = builtin.listFunction().toSet
+
/** See usage above. */
private def expression[T <: Expression](name: String)
(implicit tag: ClassTag[T]): (String, (ExpressionInfo, FunctionBuilder)) = {
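The new `functionSet` value snapshots the built-in names once, so callers can do constant-time membership checks instead of rescanning the registry. A small sketch of the intended lookup (the literal names here are illustrative):

```scala
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry

// Constant-time check against the cached set of built-in function names.
def isBuiltin(name: String): Boolean = FunctionRegistry.functionSet.contains(name)

isBuiltin("abs")     // true: "abs" is a built-in expression
isBuiltin("my_udf")  // false: user functions are never in the set
```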
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
index 943d1071e2..1ec1bb1baf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
@@ -855,7 +855,8 @@ class SessionCatalog(
.map { f => FunctionIdentifier(f, Some(dbName)) }
val loadedFunctions = StringUtils.filterPattern(functionRegistry.listFunction(), pattern)
.map { f => FunctionIdentifier(f) }
- dbFunctions ++ loadedFunctions
+ (dbFunctions ++ loadedFunctions)
+ .filterNot(f => FunctionRegistry.functionSet.contains(f.funcName))
}
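The filtering itself is plain Scala collection logic; a self-contained toy version of the idea, with stand-in data rather than the real catalog types:

```scala
// Stand-in for FunctionRegistry.functionSet.
val builtins = Set("abs", "sha1", "weekofyear")

// Stand-ins for the persistent (database) and temporary function names.
val dbFunctions     = Seq("my_func1", "my_func2")
val loadedFunctions = Seq("temp_func", "abs")

// Built-in names are dropped even if the registry happens to list them.
val visible = (dbFunctions ++ loadedFunctions).filterNot(builtins.contains)
// visible == Seq("my_func1", "my_func2", "temp_func")
```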
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 89f8685099..545c1776b7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -21,8 +21,10 @@ import java.math.MathContext
import java.sql.Timestamp
import org.apache.spark.AccumulatorSuite
+import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.analysis.UnresolvedException
-import org.apache.spark.sql.catalyst.expressions.SortOrder
+import org.apache.spark.sql.catalyst.catalog.{CatalogTestUtils, ExternalCatalog, SessionCatalog}
+import org.apache.spark.sql.catalyst.expressions.{ExpressionDescription, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical.Aggregate
import org.apache.spark.sql.catalyst.util.StringUtils
import org.apache.spark.sql.execution.aggregate
@@ -58,15 +60,40 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
test("show functions") {
def getFunctions(pattern: String): Seq[Row] = {
- StringUtils.filterPattern(spark.sessionState.functionRegistry.listFunction(), pattern)
+ StringUtils.filterPattern(
+ spark.sessionState.catalog.listFunctions("default").map(_.funcName), pattern)
.map(Row(_))
}
+
+ def createFunction(names: Seq[String]): Unit = {
+ names.foreach { name =>
+ spark.udf.register(name, (arg1: Int, arg2: String) => arg2 + arg1)
+ }
+ }
+
+ def dropFunction(names: Seq[String]): Unit = {
+ names.foreach { name =>
+ spark.sessionState.catalog.dropTempFunction(name, false)
+ }
+ }
+
+ val functions = Array("ilog", "logi", "logii", "logiii", "crc32i", "cubei", "cume_disti",
+ "isize", "ispace", "to_datei", "date_addi", "current_datei")
+
+ assert(sql("SHOW functions").collect().isEmpty)
+
+ createFunction(functions)
+
checkAnswer(sql("SHOW functions"), getFunctions("*"))
+ assert(sql("SHOW functions").collect().size === functions.size)
+ assert(sql("SHOW functions").collect().toSet === functions.map(Row(_)).toSet)
+
Seq("^c*", "*e$", "log*", "*date*").foreach { pattern =>
// For the pattern part, only '*' and '|' are allowed as wildcards.
// For '*', we need to replace it to '.*'.
checkAnswer(sql(s"SHOW FUNCTIONS '$pattern'"), getFunctions(pattern))
}
+ dropFunction(functions)
}
test("describe functions") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala
index aec0312c40..df817f863d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala
@@ -174,8 +174,7 @@ class CatalogSuite
}
test("list functions") {
- assert(Set("+", "current_database", "window").subsetOf(
- spark.catalog.listFunctions().collect().map(_.name).toSet))
+ assert(spark.catalog.listFunctions().collect().isEmpty)
createFunction("my_func1")
createFunction("my_func2")
createTempFunction("my_temp_func")
@@ -192,8 +191,7 @@ class CatalogSuite
}
test("list functions with database") {
- assert(Set("+", "current_database", "window").subsetOf(
- spark.catalog.listFunctions("default").collect().map(_.name).toSet))
+ assert(spark.catalog.listFunctions("default").collect().isEmpty)
createDatabase("my_db1")
createDatabase("my_db2")
createFunction("my_func1", Some("my_db1"))
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index 8244ff4ce0..1a0eaa66c1 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -187,28 +187,42 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
}
test("show functions") {
- val allBuiltinFunctions = FunctionRegistry.builtin.listFunction().toSet[String].toList.sorted
- // The TestContext is shared by all the test cases, some functions may be registered before
- // this, so we check that all the builtin functions are returned.
- val allFunctions = sql("SHOW functions").collect().map(r => r(0))
- allBuiltinFunctions.foreach { f =>
- assert(allFunctions.contains(f))
- }
withTempDatabase { db =>
- checkAnswer(sql("SHOW functions abs"), Row("abs"))
- checkAnswer(sql("SHOW functions 'abs'"), Row("abs"))
- checkAnswer(sql(s"SHOW functions $db.abs"), Row("abs"))
- checkAnswer(sql(s"SHOW functions `$db`.`abs`"), Row("abs"))
- checkAnswer(sql(s"SHOW functions `$db`.`abs`"), Row("abs"))
- checkAnswer(sql("SHOW functions `~`"), Row("~"))
+ def createFunction(names: Seq[String]): Unit = {
+ names.foreach { name =>
+ sql(
+ s"""
+ |CREATE TEMPORARY FUNCTION $name
+ |AS '${classOf[PairUDF].getName}'
+ """.stripMargin)
+ }
+ }
+ def dropFunction(names: Seq[String]): Unit = {
+ names.foreach { name =>
+ sql(s"DROP TEMPORARY FUNCTION $name")
+ }
+ }
+ createFunction(Seq("temp_abs", "temp_weekofyear", "temp_sha", "temp_sha1", "temp_sha2"))
+
+ checkAnswer(sql("SHOW functions temp_abs"), Row("temp_abs"))
+ checkAnswer(sql("SHOW functions 'temp_abs'"), Row("temp_abs"))
+ checkAnswer(sql(s"SHOW functions $db.temp_abs"), Row("temp_abs"))
+ checkAnswer(sql(s"SHOW functions `$db`.`temp_abs`"), Row("temp_abs"))
+ checkAnswer(sql(s"SHOW functions `$db`.`temp_abs`"), Row("temp_abs"))
checkAnswer(sql("SHOW functions `a function doens't exist`"), Nil)
- checkAnswer(sql("SHOW functions `weekofyea*`"), Row("weekofyear"))
+ checkAnswer(sql("SHOW functions `temp_weekofyea*`"), Row("temp_weekofyear"))
+
// This will probably fail if we add more functions with the `sha` prefix.
- checkAnswer(sql("SHOW functions `sha*`"), Row("sha") :: Row("sha1") :: Row("sha2") :: Nil)
+ checkAnswer(
+ sql("SHOW functions `temp_sha*`"),
+ List(Row("temp_sha"), Row("temp_sha1"), Row("temp_sha2")))
+
// Test '|' for alternation.
checkAnswer(
- sql("SHOW functions 'sha*|weekofyea*'"),
- Row("sha") :: Row("sha1") :: Row("sha2") :: Row("weekofyear") :: Nil)
+ sql("SHOW functions 'temp_sha*|temp_weekofyea*'"),
+ List(Row("temp_sha"), Row("temp_sha1"), Row("temp_sha2"), Row("temp_weekofyear")))
+
+ dropFunction(Seq("temp_abs", "temp_weekofyear", "temp_sha", "temp_sha1", "temp_sha2"))
}
}