aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorYin Huai <yhuai@databricks.com>2015-02-12 18:08:01 -0800
committerMichael Armbrust <michael@databricks.com>2015-02-12 18:08:01 -0800
commit1d0596a16e1d3add2631f5d8169aeec2876a1362 (patch)
treed691a0e0e370a13f8cd35ec7925ebddb2e159ff5 /sql
parentc025a468826e9b9f62032e207daa9d42d9dba3ca (diff)
downloadspark-1d0596a16e1d3add2631f5d8169aeec2876a1362.tar.gz
spark-1d0596a16e1d3add2631f5d8169aeec2876a1362.tar.bz2
spark-1d0596a16e1d3add2631f5d8169aeec2876a1362.zip
[SPARK-3299][SQL]Public API in SQLContext to list tables
https://issues.apache.org/jira/browse/SPARK-3299 Author: Yin Huai <yhuai@databricks.com> Closes #4547 from yhuai/tables and squashes the following commits: 6c8f92e [Yin Huai] Add tableNames. acbb281 [Yin Huai] Update Python test. 7793dcb [Yin Huai] Fix scala test. 572870d [Yin Huai] Address comments. aba2e88 [Yin Huai] Format. 12c86df [Yin Huai] Add tables() to SQLContext to return a DataFrame containing existing tables.
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala37
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala36
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala76
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala5
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala77
5 files changed, 231 insertions, 0 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala
index df8d03b86c..f57eab2460 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala
@@ -34,6 +34,12 @@ trait Catalog {
tableIdentifier: Seq[String],
alias: Option[String] = None): LogicalPlan
+ /**
+ * Returns tuples of (tableName, isTemporary) for all tables in the given database.
+ * isTemporary is a Boolean value indicating whether the table is temporary or not.
+ */
+ def getTables(databaseName: Option[String]): Seq[(String, Boolean)]
+
def registerTable(tableIdentifier: Seq[String], plan: LogicalPlan): Unit
def unregisterTable(tableIdentifier: Seq[String]): Unit
@@ -101,6 +107,12 @@ class SimpleCatalog(val caseSensitive: Boolean) extends Catalog {
// properly qualified with this alias.
alias.map(a => Subquery(a, tableWithQualifiers)).getOrElse(tableWithQualifiers)
}
+
+ override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = {
+ tables.map {
+ case (name, _) => (name, true)
+ }.toSeq
+ }
}
/**
@@ -137,6 +149,27 @@ trait OverrideCatalog extends Catalog {
withAlias.getOrElse(super.lookupRelation(tableIdentifier, alias))
}
+ abstract override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = {
+ val dbName = if (!caseSensitive) {
+ if (databaseName.isDefined) Some(databaseName.get.toLowerCase) else None
+ } else {
+ databaseName
+ }
+
+ val temporaryTables = overrides.filter {
+ // If a temporary table does not have an associated database, we should return its name.
+ case ((None, _), _) => true
+ // If a temporary table does have an associated database, we should return it if the database
+ // matches the given database name.
+ case ((db: Some[String], _), _) if db == dbName => true
+ case _ => false
+ }.map {
+ case ((_, tableName), _) => (tableName, true)
+ }.toSeq
+
+ temporaryTables ++ super.getTables(databaseName)
+ }
+
override def registerTable(
tableIdentifier: Seq[String],
plan: LogicalPlan): Unit = {
@@ -172,6 +205,10 @@ object EmptyCatalog extends Catalog {
throw new UnsupportedOperationException
}
+ override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = {
+ throw new UnsupportedOperationException
+ }
+
def registerTable(tableIdentifier: Seq[String], plan: LogicalPlan): Unit = {
throw new UnsupportedOperationException
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 8aae222acd..0f8af75fe7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -774,6 +774,42 @@ class SQLContext(@transient val sparkContext: SparkContext)
def table(tableName: String): DataFrame =
DataFrame(this, catalog.lookupRelation(Seq(tableName)))
+ /**
+ * Returns a [[DataFrame]] containing names of existing tables in the current database.
+ * The returned DataFrame has two columns, tableName and isTemporary (a column with BooleanType
+ * indicating if a table is a temporary one or not).
+ */
+ def tables(): DataFrame = {
+ createDataFrame(catalog.getTables(None)).toDataFrame("tableName", "isTemporary")
+ }
+
+ /**
+ * Returns a [[DataFrame]] containing names of existing tables in the given database.
+ * The returned DataFrame has two columns, tableName and isTemporary (a column with BooleanType
+ * indicating if a table is a temporary one or not).
+ */
+ def tables(databaseName: String): DataFrame = {
+ createDataFrame(catalog.getTables(Some(databaseName))).toDataFrame("tableName", "isTemporary")
+ }
+
+ /**
+ * Returns an array of names of tables in the current database.
+ */
+ def tableNames(): Array[String] = {
+ catalog.getTables(None).map {
+ case (tableName, _) => tableName
+ }.toArray
+ }
+
+ /**
+ * Returns an array of names of tables in the given database.
+ */
+ def tableNames(databaseName: String): Array[String] = {
+ catalog.getTables(Some(databaseName)).map {
+ case (tableName, _) => tableName
+ }.toArray
+ }
+
protected[sql] class SparkPlanner extends SparkStrategies {
val sparkContext: SparkContext = self.sparkContext
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala
new file mode 100644
index 0000000000..5fc35349e1
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala
@@ -0,0 +1,76 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql
+
+import org.scalatest.BeforeAndAfter
+
+import org.apache.spark.sql.test.TestSQLContext
+import org.apache.spark.sql.test.TestSQLContext._
+import org.apache.spark.sql.types.{BooleanType, StringType, StructField, StructType}
+
+class ListTablesSuite extends QueryTest with BeforeAndAfter {
+
+ import org.apache.spark.sql.test.TestSQLContext.implicits._
+
+ val df =
+ sparkContext.parallelize((1 to 10).map(i => (i,s"str$i"))).toDataFrame("key", "value")
+
+ before {
+ df.registerTempTable("ListTablesSuiteTable")
+ }
+
+ after {
+ catalog.unregisterTable(Seq("ListTablesSuiteTable"))
+ }
+
+ test("get all tables") {
+ checkAnswer(
+ tables().filter("tableName = 'ListTablesSuiteTable'"),
+ Row("ListTablesSuiteTable", true))
+
+ catalog.unregisterTable(Seq("ListTablesSuiteTable"))
+ assert(tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0)
+ }
+
+ test("getting all Tables with a database name has no impact on returned table names") {
+ checkAnswer(
+ tables("DB").filter("tableName = 'ListTablesSuiteTable'"),
+ Row("ListTablesSuiteTable", true))
+
+ catalog.unregisterTable(Seq("ListTablesSuiteTable"))
+ assert(tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0)
+ }
+
+ test("query the returned DataFrame of tables") {
+ val tableDF = tables()
+ val schema = StructType(
+ StructField("tableName", StringType, true) ::
+ StructField("isTemporary", BooleanType, false) :: Nil)
+ assert(schema === tableDF.schema)
+
+ tableDF.registerTempTable("tables")
+ checkAnswer(
+ sql("SELECT isTemporary, tableName from tables WHERE tableName = 'ListTablesSuiteTable'"),
+ Row(true, "ListTablesSuiteTable")
+ )
+ checkAnswer(
+ tables().filter("tableName = 'tables'").select("tableName", "isTemporary"),
+ Row("tables", true))
+ dropTempTable("tables")
+ }
+}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index c78369d12c..eb1ee54247 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -198,6 +198,11 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with
}
}
+ override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = {
+ val dbName = databaseName.getOrElse(hive.sessionState.getCurrentDatabase)
+ client.getAllTables(dbName).map(tableName => (tableName, false))
+ }
+
/**
* Create table with specified database, table name, table description and schema
* @param databaseName Database Name
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala
new file mode 100644
index 0000000000..068aa03330
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala
@@ -0,0 +1,77 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql.hive
+
+import org.scalatest.BeforeAndAfterAll
+
+import org.apache.spark.sql.hive.test.TestHive
+import org.apache.spark.sql.hive.test.TestHive._
+import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.Row
+
+class ListTablesSuite extends QueryTest with BeforeAndAfterAll {
+
+ import org.apache.spark.sql.hive.test.TestHive.implicits._
+
+ val df =
+ sparkContext.parallelize((1 to 10).map(i => (i,s"str$i"))).toDataFrame("key", "value")
+
+ override def beforeAll(): Unit = {
+ // The catalog in HiveContext is a case insensitive one.
+ catalog.registerTable(Seq("ListTablesSuiteTable"), df.logicalPlan)
+ catalog.registerTable(Seq("ListTablesSuiteDB", "InDBListTablesSuiteTable"), df.logicalPlan)
+ sql("CREATE TABLE HiveListTablesSuiteTable (key int, value string)")
+ sql("CREATE DATABASE IF NOT EXISTS ListTablesSuiteDB")
+ sql("CREATE TABLE ListTablesSuiteDB.HiveInDBListTablesSuiteTable (key int, value string)")
+ }
+
+ override def afterAll(): Unit = {
+ catalog.unregisterTable(Seq("ListTablesSuiteTable"))
+ catalog.unregisterTable(Seq("ListTablesSuiteDB", "InDBListTablesSuiteTable"))
+ sql("DROP TABLE IF EXISTS HiveListTablesSuiteTable")
+ sql("DROP TABLE IF EXISTS ListTablesSuiteDB.HiveInDBListTablesSuiteTable")
+ sql("DROP DATABASE IF EXISTS ListTablesSuiteDB")
+ }
+
+ test("get all tables of current database") {
+ val allTables = tables()
+ // We are using default DB.
+ checkAnswer(
+ allTables.filter("tableName = 'listtablessuitetable'"),
+ Row("listtablessuitetable", true))
+ assert(allTables.filter("tableName = 'indblisttablessuitetable'").count() === 0)
+ checkAnswer(
+ allTables.filter("tableName = 'hivelisttablessuitetable'"),
+ Row("hivelisttablessuitetable", false))
+ assert(allTables.filter("tableName = 'hiveindblisttablessuitetable'").count() === 0)
+ }
+
+ test("getting all tables with a database name") {
+ val allTables = tables("ListTablesSuiteDB")
+ checkAnswer(
+ allTables.filter("tableName = 'listtablessuitetable'"),
+ Row("listtablessuitetable", true))
+ checkAnswer(
+ allTables.filter("tableName = 'indblisttablessuitetable'"),
+ Row("indblisttablessuitetable", true))
+ assert(allTables.filter("tableName = 'hivelisttablessuitetable'").count() === 0)
+ checkAnswer(
+ allTables.filter("tableName = 'hiveindblisttablessuitetable'"),
+ Row("hiveindblisttablessuitetable", false))
+ }
+}