 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala       | 32
 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala  | 33
 sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala                           |  3
 sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala                      |  8
 sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala                     |  5
 5 files changed, 71 insertions(+), 10 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
index d88b5ffc05..c0ebb2b1fa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
@@ -22,7 +22,7 @@ import javax.annotation.concurrent.GuardedBy
 import scala.collection.mutable

 import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.{FileSystem, Path}
+import org.apache.hadoop.fs.Path

 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.AnalysisException
@@ -253,9 +253,27 @@ class SessionCatalog(
   def getTableMetadata(name: TableIdentifier): CatalogTable = {
     val db = formatDatabaseName(name.database.getOrElse(getCurrentDatabase))
     val table = formatTableName(name.table)
-    requireDbExists(db)
-    requireTableExists(TableIdentifier(table, Some(db)))
-    externalCatalog.getTable(db, table)
+    val tid = TableIdentifier(table)
+    if (isTemporaryTable(name)) {
+      CatalogTable(
+        identifier = tid,
+        tableType = CatalogTableType.VIEW,
+        storage = CatalogStorageFormat.empty,
+        schema = tempTables(table).output.map { c =>
+          CatalogColumn(
+            name = c.name,
+            dataType = c.dataType.catalogString,
+            nullable = c.nullable,
+            comment = Option(c.name)
+          )
+        },
+        properties = Map(),
+        viewText = None)
+    } else {
+      requireDbExists(db)
+      requireTableExists(TableIdentifier(table, Some(db)))
+      externalCatalog.getTable(db, table)
+    }
   }

   /**
@@ -432,10 +450,10 @@ class SessionCatalog(
   def tableExists(name: TableIdentifier): Boolean = synchronized {
     val db = formatDatabaseName(name.database.getOrElse(currentDb))
     val table = formatTableName(name.table)
-    if (name.database.isDefined || !tempTables.contains(table)) {
-      externalCatalog.tableExists(db, table)
+    if (isTemporaryTable(name)) {
+      true
     } else {
-      true // it's a temporary table
+      externalCatalog.tableExists(db, table)
     }
   }
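
Both rewritten methods lean on SessionCatalog.isTemporaryTable, which this patch does not show. For the new branches to behave as the tests below expect, that predicate must treat only unqualified names as candidates for temp views; a minimal sketch of that assumption:

    // Assumed shape of the predicate used above (not part of this diff):
    // a name can resolve to a temporary view only when it carries no
    // database qualifier, so "default.view1" never hits tempTables.
    def isTemporaryTable(name: TableIdentifier): Boolean = synchronized {
      name.database.isEmpty && tempTables.contains(formatTableName(name.table))
    }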
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala
index 05eb302c3c..adce5df81c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala
@@ -432,6 +432,39 @@ class SessionCatalogSuite extends SparkFunSuite {
     assert(catalog.tableExists(TableIdentifier("tbl3")))
   }

+  test("tableExists on temporary views") {
+    val catalog = new SessionCatalog(newBasicCatalog())
+    val tempTable = Range(1, 10, 2, 10)
+    assert(!catalog.tableExists(TableIdentifier("view1")))
+    assert(!catalog.tableExists(TableIdentifier("view1", Some("default"))))
+    catalog.createTempView("view1", tempTable, overrideIfExists = false)
+    assert(catalog.tableExists(TableIdentifier("view1")))
+    assert(!catalog.tableExists(TableIdentifier("view1", Some("default"))))
+  }
+
+  test("getTableMetadata on temporary views") {
+    val catalog = new SessionCatalog(newBasicCatalog())
+    val tempTable = Range(1, 10, 2, 10)
+    val m = intercept[AnalysisException] {
+      catalog.getTableMetadata(TableIdentifier("view1"))
+    }.getMessage
+    assert(m.contains("Table or view 'view1' not found in database 'default'"))
+
+    val m2 = intercept[AnalysisException] {
+      catalog.getTableMetadata(TableIdentifier("view1", Some("default")))
+    }.getMessage
+    assert(m2.contains("Table or view 'view1' not found in database 'default'"))
+
+    catalog.createTempView("view1", tempTable, overrideIfExists = false)
+    assert(catalog.getTableMetadata(TableIdentifier("view1")).identifier.table == "view1")
+    assert(catalog.getTableMetadata(TableIdentifier("view1")).schema(0).name == "id")
+
+    val m3 = intercept[AnalysisException] {
+      catalog.getTableMetadata(TableIdentifier("view1", Some("default")))
+    }.getMessage
+    assert(m3.contains("Table or view 'view1' not found in database 'default'"))
+  }
+
   test("list tables without pattern") {
     val catalog = new SessionCatalog(newBasicCatalog())
     val tempTable = Range(1, 10, 2, 10)
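
Taken together, the two new tests pin down the shadowing rule: an unqualified name prefers the temp view, while a database-qualified name bypasses temp views entirely. A condensed, hypothetical restatement:

    // Hypothetical condensation of the assertions above
    // (newBasicCatalog() is the suite's existing fixture helper):
    val catalog = new SessionCatalog(newBasicCatalog())
    catalog.createTempView("view1", Range(1, 10, 2, 10), overrideIfExists = false)
    catalog.tableExists(TableIdentifier("view1"))                   // true: bare name resolves to the temp view
    catalog.tableExists(TableIdentifier("view1", Some("default")))  // false: qualified lookup skips temp views
    catalog.getTableMetadata(TableIdentifier("view1")).tableType    // CatalogTableType.VIEW, synthesized in memory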
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala
index 91ed9b3258..1aed245fdd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala
@@ -85,7 +85,8 @@ abstract class Catalog {
   def listFunctions(dbName: String): Dataset[Function]

   /**
-   * Returns a list of columns for the given table in the current database.
+   * Returns a list of columns for the given table in the current database or
+   * the given temporary table.
    *
    * @since 2.0.0
    */
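
At the public API level the doc change corresponds to calls like the sketch below (assumes a running SparkSession named spark; the view name is illustrative):

    // Usage sketch for the documented behaviour:
    spark.range(10).createOrReplaceTempView("temp1")   // register a temp view
    spark.catalog.listColumns("temp1").show()          // resolves without a database qualifier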
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala
index 44babcc93a..a6ae6fe2aa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala
@@ -138,7 +138,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
    */
   @throws[AnalysisException]("table does not exist")
   override def listColumns(tableName: String): Dataset[Column] = {
-    listColumns(currentDatabase, tableName)
+    listColumns(TableIdentifier(tableName, None))
   }

   /**
@@ -147,7 +147,11 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
   @throws[AnalysisException]("database or table does not exist")
   override def listColumns(dbName: String, tableName: String): Dataset[Column] = {
     requireTableExists(dbName, tableName)
-    val tableMetadata = sessionCatalog.getTableMetadata(TableIdentifier(tableName, Some(dbName)))
+    listColumns(TableIdentifier(tableName, Some(dbName)))
+  }
+
+  private def listColumns(tableIdentifier: TableIdentifier): Dataset[Column] = {
+    val tableMetadata = sessionCatalog.getTableMetadata(tableIdentifier)
     val partitionColumnNames = tableMetadata.partitionColumnNames.toSet
     val bucketColumnNames = tableMetadata.bucketColumnNames.toSet
     val columns = tableMetadata.schema.map { c =>
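
The hunk is cut off inside the new private overload. A sketch of how it plausibly finishes, mirroring the body the two-argument public method carried before the refactor (the Column field names match org.apache.spark.sql.catalog.Column, but this tail is an assumption, not part of the diff):

    // Assumed remainder of the private listColumns overload:
    val columns = tableMetadata.schema.map { c =>
      new Column(
        name = c.name,
        description = c.comment.orNull,
        dataType = c.dataType,
        nullable = c.nullable,
        isPartition = partitionColumnNames.contains(c.name),
        isBucket = bucketColumnNames.contains(c.name))
    }
    CatalogImpl.makeDataset(columns, sparkSession)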
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala
index d862e4cfa9..d75df56dd6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala
@@ -234,6 +234,11 @@ class CatalogSuite
     testListColumns("tab1", dbName = None)
   }

+  test("list columns in temporary table") {
+    createTempTable("temp1")
+    spark.catalog.listColumns("temp1")
+  }
+
   test("list columns in database") {
     createDatabase("db1")
     createTable("tab1", Some("db1"))
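
The added test is a smoke test: it only verifies the call no longer throws for a temp table. A hypothetical stronger variant, were the suite to assert on the result as well (createTempTable is the suite's existing helper; none of this is in the patch):

    test("list columns in temporary table") {
      createTempTable("temp1")
      val cols = spark.catalog.listColumns("temp1").collect()
      assert(cols.nonEmpty)   // the temp view's schema is surfaced instead of an AnalysisException
    }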