aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst
diff options
context:
space:
mode:
Diffstat (limited to 'sql/catalyst')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala55
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala69
2 files changed, 104 insertions, 20 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala
index f30b5d8167..0d05d9808b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala
@@ -25,6 +25,9 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Subquery}
* An interface for looking up relations by name. Used by an [[Analyzer]].
*/
trait Catalog {
+
+ def caseSensitive: Boolean
+
def lookupRelation(
databaseName: Option[String],
tableName: String,
@@ -35,22 +38,44 @@ trait Catalog {
def unregisterTable(databaseName: Option[String], tableName: String): Unit
def unregisterAllTables(): Unit
+
+ protected def processDatabaseAndTableName(
+ databaseName: Option[String],
+ tableName: String): (Option[String], String) = {
+ if (!caseSensitive) {
+ (databaseName.map(_.toLowerCase), tableName.toLowerCase)
+ } else {
+ (databaseName, tableName)
+ }
+ }
+
+ protected def processDatabaseAndTableName(
+ databaseName: String,
+ tableName: String): (String, String) = {
+ if (!caseSensitive) {
+ (databaseName.toLowerCase, tableName.toLowerCase)
+ } else {
+ (databaseName, tableName)
+ }
+ }
}
-class SimpleCatalog extends Catalog {
+class SimpleCatalog(val caseSensitive: Boolean) extends Catalog {
val tables = new mutable.HashMap[String, LogicalPlan]()
override def registerTable(
databaseName: Option[String],
tableName: String,
plan: LogicalPlan): Unit = {
- tables += ((tableName, plan))
+ val (dbName, tblName) = processDatabaseAndTableName(databaseName, tableName)
+ tables += ((tblName, plan))
}
override def unregisterTable(
databaseName: Option[String],
tableName: String) = {
- tables -= tableName
+ val (dbName, tblName) = processDatabaseAndTableName(databaseName, tableName)
+ tables -= tblName
}
override def unregisterAllTables() = {
@@ -61,12 +86,13 @@ class SimpleCatalog extends Catalog {
databaseName: Option[String],
tableName: String,
alias: Option[String] = None): LogicalPlan = {
- val table = tables.get(tableName).getOrElse(sys.error(s"Table Not Found: $tableName"))
- val tableWithQualifiers = Subquery(tableName, table)
+ val (dbName, tblName) = processDatabaseAndTableName(databaseName, tableName)
+ val table = tables.get(tblName).getOrElse(sys.error(s"Table Not Found: $tableName"))
+ val tableWithQualifiers = Subquery(tblName, table)
// If an alias was specified by the lookup, wrap the plan in a subquery so that attributes are
// properly qualified with this alias.
- alias.map(a => Subquery(a.toLowerCase, tableWithQualifiers)).getOrElse(tableWithQualifiers)
+ alias.map(a => Subquery(a, tableWithQualifiers)).getOrElse(tableWithQualifiers)
}
}
@@ -85,26 +111,28 @@ trait OverrideCatalog extends Catalog {
databaseName: Option[String],
tableName: String,
alias: Option[String] = None): LogicalPlan = {
-
- val overriddenTable = overrides.get((databaseName, tableName))
+ val (dbName, tblName) = processDatabaseAndTableName(databaseName, tableName)
+ val overriddenTable = overrides.get((dbName, tblName))
// If an alias was specified by the lookup, wrap the plan in a subquery so that attributes are
// properly qualified with this alias.
val withAlias =
- overriddenTable.map(r => alias.map(a => Subquery(a.toLowerCase, r)).getOrElse(r))
+ overriddenTable.map(r => alias.map(a => Subquery(a, r)).getOrElse(r))
- withAlias.getOrElse(super.lookupRelation(databaseName, tableName, alias))
+ withAlias.getOrElse(super.lookupRelation(dbName, tblName, alias))
}
override def registerTable(
databaseName: Option[String],
tableName: String,
plan: LogicalPlan): Unit = {
- overrides.put((databaseName, tableName), plan)
+ val (dbName, tblName) = processDatabaseAndTableName(databaseName, tableName)
+ overrides.put((dbName, tblName), plan)
}
override def unregisterTable(databaseName: Option[String], tableName: String): Unit = {
- overrides.remove((databaseName, tableName))
+ val (dbName, tblName) = processDatabaseAndTableName(databaseName, tableName)
+ overrides.remove((dbName, tblName))
}
override def unregisterAllTables(): Unit = {
@@ -117,6 +145,9 @@ trait OverrideCatalog extends Catalog {
* relations are already filled in and the analyser needs only to resolve attribute references.
*/
object EmptyCatalog extends Catalog {
+
+ val caseSensitive: Boolean = true
+
def lookupRelation(
databaseName: Option[String],
tableName: String,
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index f14df81376..0a4fde3de7 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -17,28 +17,81 @@
package org.apache.spark.sql.catalyst.analysis
-import org.scalatest.FunSuite
+import org.scalatest.{BeforeAndAfter, FunSuite}
+import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.types.IntegerType
-/* Implicit conversions */
-import org.apache.spark.sql.catalyst.dsl.expressions._
+class AnalysisSuite extends FunSuite with BeforeAndAfter {
+ val caseSensitiveCatalog = new SimpleCatalog(true)
+ val caseInsensitiveCatalog = new SimpleCatalog(false)
+ val caseSensitiveAnalyze =
+ new Analyzer(caseSensitiveCatalog, EmptyFunctionRegistry, caseSensitive = true)
+ val caseInsensitiveAnalyze =
+ new Analyzer(caseInsensitiveCatalog, EmptyFunctionRegistry, caseSensitive = false)
-class AnalysisSuite extends FunSuite {
- val analyze = SimpleAnalyzer
+ val testRelation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)())
- val testRelation = LocalRelation('a.int)
+ before {
+ caseSensitiveCatalog.registerTable(None, "TaBlE", testRelation)
+ caseInsensitiveCatalog.registerTable(None, "TaBlE", testRelation)
+ }
test("analyze project") {
assert(
- analyze(Project(Seq(UnresolvedAttribute("a")), testRelation)) ===
+ caseSensitiveAnalyze(Project(Seq(UnresolvedAttribute("a")), testRelation)) ===
+ Project(testRelation.output, testRelation))
+
+ assert(
+ caseSensitiveAnalyze(
+ Project(Seq(UnresolvedAttribute("TbL.a")),
+ UnresolvedRelation(None, "TaBlE", Some("TbL")))) ===
+ Project(testRelation.output, testRelation))
+
+ val e = intercept[TreeNodeException[_]] {
+ caseSensitiveAnalyze(
+ Project(Seq(UnresolvedAttribute("tBl.a")),
+ UnresolvedRelation(None, "TaBlE", Some("TbL"))))
+ }
+ assert(e.getMessage().toLowerCase.contains("unresolved"))
+
+ assert(
+ caseInsensitiveAnalyze(
+ Project(Seq(UnresolvedAttribute("TbL.a")),
+ UnresolvedRelation(None, "TaBlE", Some("TbL")))) ===
Project(testRelation.output, testRelation))
+
+ assert(
+ caseInsensitiveAnalyze(
+ Project(Seq(UnresolvedAttribute("tBl.a")),
+ UnresolvedRelation(None, "TaBlE", Some("TbL")))) ===
+ Project(testRelation.output, testRelation))
+ }
+
+ test("resolve relations") {
+ val e = intercept[RuntimeException] {
+ caseSensitiveAnalyze(UnresolvedRelation(None, "tAbLe", None))
+ }
+ assert(e.getMessage === "Table Not Found: tAbLe")
+
+ assert(
+ caseSensitiveAnalyze(UnresolvedRelation(None, "TaBlE", None)) ===
+ testRelation)
+
+ assert(
+ caseInsensitiveAnalyze(UnresolvedRelation(None, "tAbLe", None)) ===
+ testRelation)
+
+ assert(
+ caseInsensitiveAnalyze(UnresolvedRelation(None, "TaBlE", None)) ===
+ testRelation)
}
test("throw errors for unresolved attributes during analysis") {
val e = intercept[TreeNodeException[_]] {
- analyze(Project(Seq(UnresolvedAttribute("abcd")), testRelation))
+ caseSensitiveAnalyze(Project(Seq(UnresolvedAttribute("abcd")), testRelation))
}
assert(e.getMessage().toLowerCase.contains("unresolved"))
}