-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala              | 21
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala      |  4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala      |  2
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala         | 13
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala | 15
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala                               | 45
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala                            | 11
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala                   | 16
8 files changed, 107 insertions(+), 20 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 4a95240741..574d96d929 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -40,7 +40,12 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
// TODO: pass this in as a parameter.
val fixedPoint = FixedPoint(100)
- val batches: Seq[Batch] = Seq(
+ /**
+ * Override to provide additional rules for the "Resolution" batch.
+ */
+ val extendedRules: Seq[Rule[LogicalPlan]] = Nil
+
+ lazy val batches: Seq[Batch] = Seq(
Batch("MultiInstanceRelations", Once,
NewRelationInstances),
Batch("CaseInsensitiveAttributeReferences", Once,
@@ -54,8 +59,9 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
StarExpansion ::
ResolveFunctions ::
GlobalAggregates ::
- UnresolvedHavingClauseAttributes ::
- typeCoercionRules :_*),
+ UnresolvedHavingClauseAttributes ::
+ typeCoercionRules ++
+ extendedRules : _*),
Batch("Check Analysis", Once,
CheckResolution),
Batch("AnalysisOperators", fixedPoint,
@@ -63,7 +69,7 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
)
/**
- * Makes sure all attributes have been resolved.
+ * Makes sure all attributes and logical plans have been resolved.
*/
object CheckResolution extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = {
@@ -71,6 +77,13 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
case p if p.expressions.exists(!_.resolved) =>
throw new TreeNodeException(p,
s"Unresolved attributes: ${p.expressions.filterNot(_.resolved).mkString(",")}")
+ case p if !p.resolved && p.childrenResolved =>
+ throw new TreeNodeException(p, "Unresolved plan found")
+ } match {
+ // As a backstop, use the root node to check that the entire plan tree is resolved.
+ case p if !p.resolved =>
+ throw new TreeNodeException(p, "Unresolved plan in tree")
+ case p => p
}
}
}
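
The `extendedRules` hook above lets a subclass inject extra rules into the "Resolution" batch without redefining `batches`; making `batches` a `lazy val` is what allows this, since it is now built after subclass construction and so observes the override rather than the base class's `Nil`. A minimal sketch of the extension point (the rule is a hypothetical placeholder, and `catalog`/`registry` are assumed to be in scope):

  import org.apache.spark.sql.catalyst.analysis.Analyzer
  import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
  import org.apache.spark.sql.catalyst.rules.Rule

  // Hypothetical no-op rule; a real rule would rewrite unresolved nodes.
  object CustomResolutionRule extends Rule[LogicalPlan] {
    def apply(plan: LogicalPlan): LogicalPlan = plan
  }

  val analyzer = new Analyzer(catalog, registry, caseSensitive = true) {
    override val extendedRules: Seq[Rule[LogicalPlan]] = CustomResolutionRule :: Nil
  }

This is the same pattern `HiveContext` adopts later in this patch.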
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index bd8131c9af..79e5283e86 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -286,6 +286,10 @@ trait HiveTypeCoercion {
// If the data type is not boolean and is being cast boolean, turn it into a comparison
// with the numeric value, i.e. x != 0. This will coerce the type into numeric type.
case Cast(e, BooleanType) if e.dataType != BooleanType => Not(EqualTo(e, Literal(0)))
+ // Stringify boolean if casting to StringType.
+ // TODO: Ensure the letter casing of the "true"/"false" strings is consistent with Hive in all cases.
+ case Cast(e, StringType) if e.dataType == BooleanType =>
+ If(e, Literal("true"), Literal("false"))
// Turn true into 1, and false into 0 if casting boolean into other types.
case Cast(e, dataType) if e.dataType == BooleanType =>
Cast(If(e, Literal(1), Literal(0)), dataType)
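
The effect of the new case, traced on a literal (illustrative only; `Cast`, `If`, and `Literal` are the Catalyst expressions used above):

  import org.apache.spark.sql.catalyst.expressions.{Cast, If, Literal}
  import org.apache.spark.sql.catalyst.types.StringType

  // Before the rule fires: a direct boolean-to-string cast.
  val before = Cast(Literal(true), StringType)
  // After: the cast becomes a conditional on the boolean, evaluating to
  // "true"/"false" instead of the "1"/"0" the generic boolean case below
  // would otherwise produce.
  val after = If(Literal(true), Literal("true"), Literal("false"))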
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index bae491f07c..ede431ad4a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -58,7 +58,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] {
/**
* Returns true if this expression and all its children have been resolved to a specific schema
- * and false if it is still contains any unresolved placeholders. Implementations of LogicalPlan
+ * and false if it still contains any unresolved placeholders. Implementations of LogicalPlan
* can override this (e.g.
* [[org.apache.spark.sql.catalyst.analysis.UnresolvedRelation UnresolvedRelation]]
* should return `false`).
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 0a4fde3de7..5809a108ff 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -93,6 +93,17 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
val e = intercept[TreeNodeException[_]] {
caseSensitiveAnalyze(Project(Seq(UnresolvedAttribute("abcd")), testRelation))
}
- assert(e.getMessage().toLowerCase.contains("unresolved"))
+ assert(e.getMessage().toLowerCase.contains("unresolved attribute"))
+ }
+
+ test("throw errors for unresolved plans during analysis") {
+ case class UnresolvedTestPlan() extends LeafNode {
+ override lazy val resolved = false
+ override def output = Nil
+ }
+ val e = intercept[TreeNodeException[_]] {
+ caseSensitiveAnalyze(UnresolvedTestPlan())
+ }
+ assert(e.getMessage().toLowerCase.contains("unresolved plan"))
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
index ba8b853b6f..baeb9b0cf5 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.analysis
import org.scalatest.FunSuite
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
import org.apache.spark.sql.catalyst.types._
class HiveTypeCoercionSuite extends FunSuite {
@@ -84,4 +86,17 @@ class HiveTypeCoercionSuite extends FunSuite {
widenTest(StringType, MapType(IntegerType, StringType, true), None)
widenTest(ArrayType(IntegerType), StructType(Seq()), None)
}
+
+ test("boolean casts") {
+ val booleanCasts = new HiveTypeCoercion { }.BooleanCasts
+ def ruleTest(initial: Expression, transformed: Expression) {
+ val testRelation = LocalRelation(AttributeReference("a", IntegerType)())
+ assert(booleanCasts(Project(Seq(Alias(initial, "a")()), testRelation)) ==
+ Project(Seq(Alias(transformed, "a")()), testRelation))
+ }
+ // Remove superfluous boolean -> boolean casts.
+ ruleTest(Cast(Literal(true), BooleanType), Literal(true))
+ // Stringify boolean when casting to string.
+ ruleTest(Cast(Literal(false), StringType), If(Literal(false), Literal("true"), Literal("false")))
+ }
}
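
One idiom worth noting in the suite above: `HiveTypeCoercion` is a trait whose rules are nested objects, so the test reaches `BooleanCasts` by instantiating an anonymous subclass:

  // The rules live inside the trait, so grab one off a throwaway instance.
  val booleanCasts = new HiveTypeCoercion { }.BooleanCasts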
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 514ac543df..67563b6c55 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql
+import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.test._
import org.scalatest.BeforeAndAfterAll
@@ -477,18 +478,48 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
(3, null)))
}
- test("EXCEPT") {
+ test("UNION") {
+ checkAnswer(
+ sql("SELECT * FROM lowerCaseData UNION SELECT * FROM upperCaseData"),
+ (1, "A") :: (1, "a") :: (2, "B") :: (2, "b") :: (3, "C") :: (3, "c") ::
+ (4, "D") :: (4, "d") :: (5, "E") :: (6, "F") :: Nil)
+ checkAnswer(
+ sql("SELECT * FROM lowerCaseData UNION SELECT * FROM lowerCaseData"),
+ (1, "a") :: (2, "b") :: (3, "c") :: (4, "d") :: Nil)
+ checkAnswer(
+ sql("SELECT * FROM lowerCaseData UNION ALL SELECT * FROM lowerCaseData"),
+ (1, "a") :: (1, "a") :: (2, "b") :: (2, "b") :: (3, "c") :: (3, "c") ::
+ (4, "d") :: (4, "d") :: Nil)
+ }
+
+ test("UNION with column mismatches") {
+ // Column name mismatches are allowed.
+ checkAnswer(
+ sql("SELECT n,l FROM lowerCaseData UNION SELECT N as x1, L as x2 FROM upperCaseData"),
+ (1, "A") :: (1, "a") :: (2, "B") :: (2, "b") :: (3, "C") :: (3, "c") ::
+ (4, "D") :: (4, "d") :: (5, "E") :: (6, "F") :: Nil)
+ // Column type mismatches are not allowed as-is; they trigger a widening type coercion.
checkAnswer(
- sql("SELECT * FROM lowerCaseData EXCEPT SELECT * FROM upperCaseData "),
+ sql("SELECT n FROM lowerCaseData UNION SELECT L FROM upperCaseData"),
+ ("1" :: "2" :: "3" :: "4" :: "A" :: "B" :: "C" :: "D" :: "E" :: "F" :: Nil).map(Tuple1(_)))
+ // Column type mismatches where a coercion is not possible, in this case between integer
+ // and array types, trigger a TreeNodeException.
+ intercept[TreeNodeException[_]] {
+ sql("SELECT data FROM arrayData UNION SELECT 1 FROM arrayData").collect()
+ }
+ }
+
+ test("EXCEPT") {
+ checkAnswer(
+ sql("SELECT * FROM lowerCaseData EXCEPT SELECT * FROM upperCaseData"),
(1, "a") ::
(2, "b") ::
(3, "c") ::
(4, "d") :: Nil)
checkAnswer(
- sql("SELECT * FROM lowerCaseData EXCEPT SELECT * FROM lowerCaseData "), Nil)
+ sql("SELECT * FROM lowerCaseData EXCEPT SELECT * FROM lowerCaseData"), Nil)
checkAnswer(
- sql("SELECT * FROM upperCaseData EXCEPT SELECT * FROM upperCaseData "), Nil)
+ sql("SELECT * FROM upperCaseData EXCEPT SELECT * FROM upperCaseData"), Nil)
}
test("INTERSECT") {
@@ -634,6 +665,12 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
sql("SELECT key, value FROM testData WHERE key BETWEEN 9 and 7"),
Seq()
)
+ }
+
+ test("cast boolean to string") {
+ // TODO: Ensure the letter casing of the "true"/"false" strings is consistent with Hive in all cases.
+ checkAnswer(
+ sql("SELECT CAST(TRUE AS STRING), CAST(FALSE AS STRING) FROM testData LIMIT 1"),
+ ("true", "false") :: Nil)
}
}
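
The widening behavior the new UNION tests pin down, as a usage sketch (assumes the suite's `TestSQLContext` imports and test fixtures are in scope; `n` is an integer column, `L` a string column):

  // Both branches are coerced to StringType, so the integer side comes
  // back as the strings "1".."4" alongside "A".."F".
  val widened = sql("SELECT n FROM lowerCaseData UNION SELECT L FROM upperCaseData")
  widened.collect().foreach(println)

When no common type exists, as in the integer-vs-array case above, the plan stays unresolved and the strengthened `CheckResolution` rule fails analysis with a `TreeNodeException`.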
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index ced8397972..e0be09e679 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -262,7 +262,13 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
/* An analyzer that uses the Hive metastore. */
@transient
override protected[sql] lazy val analyzer =
- new Analyzer(catalog, functionRegistry, caseSensitive = false)
+ new Analyzer(catalog, functionRegistry, caseSensitive = false) {
+ override val extendedRules =
+ catalog.CreateTables ::
+ catalog.PreInsertionCasts ::
+ ExtractPythonUdfs ::
+ Nil
+ }
/**
* Runs the specified SQL query using Hive.
@@ -353,9 +359,6 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
/** Extends QueryExecution with hive specific features. */
protected[sql] abstract class QueryExecution extends super.QueryExecution {
- // TODO: Create mixin for the analyzer instead of overriding things here.
- override lazy val optimizedPlan =
- optimizer(ExtractPythonUdfs(catalog.PreInsertionCasts(catalog.CreateTables(analyzed))))
override lazy val toRdd: RDD[Row] = executedPlan.execute().map(_.copy())
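
The three rules deleted from `optimizedPlan` here are exactly the ones now supplied through `extendedRules` above, so nothing is lost; the rewrite just moves from a single post-analysis pass into the analyzer's fixed-point "Resolution" batch. Side by side (both fragments taken from this patch):

  // Before: applied once, between analysis and optimization.
  override lazy val optimizedPlan =
    optimizer(ExtractPythonUdfs(catalog.PreInsertionCasts(catalog.CreateTables(analyzed))))

  // After: the same rules run to fixed point, interleaved with resolution.
  override val extendedRules =
    catalog.CreateTables :: catalog.PreInsertionCasts :: ExtractPythonUdfs :: Nil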
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index 6571c35499..dfa2a7a9d2 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -109,6 +109,9 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with
*/
object CreateTables extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ // Wait until children are resolved.
+ case p: LogicalPlan if !p.childrenResolved => p
+
case InsertIntoCreatedTable(db, tableName, child) =>
val (dbName, tblName) = processDatabaseAndTableName(db, tableName)
val databaseName = dbName.getOrElse(hive.sessionState.getCurrentDatabase)
@@ -116,8 +119,7 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with
createTable(databaseName, tblName, child.output)
InsertIntoTable(
- EliminateAnalysisOperators(
- lookupRelation(Some(databaseName), tblName, None)),
+ lookupRelation(Some(databaseName), tblName, None),
Map.empty,
child,
overwrite = false)
@@ -130,15 +132,17 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with
*/
object PreInsertionCasts extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan.transform {
- // Wait until children are resolved
+ // Wait until children are resolved.
case p: LogicalPlan if !p.childrenResolved => p
- case p @ InsertIntoTable(table: MetastoreRelation, _, child, _) =>
+ case p @ InsertIntoTable(
+ LowerCaseSchema(table: MetastoreRelation), _, child, _) =>
castChildOutput(p, table, child)
case p @ logical.InsertIntoTable(
- InMemoryRelation(_, _, _,
- HiveTableScan(_, table, _)), _, child, _) =>
+ LowerCaseSchema(
+ InMemoryRelation(_, _, _,
+ HiveTableScan(_, table, _))), _, child, _) =>
castChildOutput(p, table, child)
}
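
Two closing notes on this file. First, since these rules now fire during analysis, before wrapper nodes like `LowerCaseSchema` have been eliminated (note the removal of `EliminateAnalysisOperators` in `CreateTables`), the `PreInsertionCasts` patterns must match through the wrapper explicitly. Second, because both rules now run repeatedly inside the fixed-point "Resolution" batch rather than once after analysis, each must tolerate partially resolved trees; hence the new "Wait until children are resolved" case in `CreateTables`, mirroring the existing one in `PreInsertionCasts`. The skeleton such a rule follows (hypothetical rule name; real rewrite cases are elided):

  import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
  import org.apache.spark.sql.catalyst.rules.Rule

  object DeferringRule extends Rule[LogicalPlan] {
    def apply(plan: LogicalPlan): LogicalPlan = plan transform {
      // Leave nodes untouched until every child has a resolved schema;
      // the fixed-point batch will revisit them on a later iteration.
      case p: LogicalPlan if !p.childrenResolved => p
      // (the rule's real rewrite cases would follow here)
    }
  }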