author     Michael Armbrust <michael@databricks.com>  2015-02-23 17:34:54 -0800
committer  Michael Armbrust <michael@databricks.com>  2015-02-23 17:34:54 -0800
commit     1ed57086d402c38d95cda6c3d9d7aea806609bf9
tree       fb92a551881535edd2bb9c8c234d901d81e10876
parent     48376bfe9c97bf31279918def6c6615849c88f4d
[SPARK-5873][SQL] Allow viewing of partially analyzed plans in queryExecution
Author: Michael Armbrust <michael@databricks.com>

Closes #4684 from marmbrus/explainAnalysis and squashes the following commits:

afbaa19 [Michael Armbrust] fix python
d93278c [Michael Armbrust] fix hive
e5fa0a4 [Michael Armbrust] Merge remote-tracking branch 'origin/master' into explainAnalysis
52119f2 [Michael Armbrust] more tests
82a5431 [Michael Armbrust] fix tests
25753d2 [Michael Armbrust] Merge remote-tracking branch 'origin/master' into explainAnalysis
aee1e6a [Michael Armbrust] fix hive
b23a844 [Michael Armbrust] newline
de8dc51 [Michael Armbrust] more comments
acf620a [Michael Armbrust] [SPARK-5873][SQL] Show partially analyzed plans in query execution
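As a rough illustration of what this change enables (the SparkContext setup, table, and column names below are invented, and it assumes the Spark 1.3-era DataFrame API where queryExecution is exposed as a DeveloperApi): once eager DataFrame analysis is switched off via the spark.sql.eagerAnalysis flag introduced in this patch, constructing an invalid DataFrame no longer throws, and the partially analyzed plan can be inspected through queryExecution.

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.sql.SQLContext

    val sc = new SparkContext(new SparkConf().setMaster("local").setAppName("partial-analysis-demo"))
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    // Disable eager analysis so that building an invalid DataFrame does not throw.
    sqlContext.setConf("spark.sql.eagerAnalysis", "false")

    val people = sc.parallelize(Seq((1, "a"))).toDF("id", "name")
    // "missingCol" does not exist, but with eager analysis off this still succeeds ...
    val bad = people.select("missingCol")
    // ... and the partially analyzed plan, unresolved attribute and all, can be printed.
    println(bad.queryExecution.analyzed)
    // Forcing the checks still reports the error via the new CheckAnalysis:
    // bad.queryExecution.assertAnalyzed()   // throws AnalysisException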
python/pyspark/sql/context.py | 30
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala | 2
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 83
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala | 105
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala | 35
sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala | 2
sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala | 5
sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala | 14
sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala | 10
sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala | 1
sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala | 2
sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala | 1
12 files changed, 164 insertions(+), 126 deletions(-)
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index 313f15e6d9..125933c9d3 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -267,20 +267,20 @@ class SQLContext(object):
... StructField("byte2", ByteType(), False),
... StructField("short1", ShortType(), False),
... StructField("short2", ShortType(), False),
- ... StructField("int", IntegerType(), False),
- ... StructField("float", FloatType(), False),
- ... StructField("date", DateType(), False),
- ... StructField("time", TimestampType(), False),
- ... StructField("map",
+ ... StructField("int1", IntegerType(), False),
+ ... StructField("float1", FloatType(), False),
+ ... StructField("date1", DateType(), False),
+ ... StructField("time1", TimestampType(), False),
+ ... StructField("map1",
... MapType(StringType(), IntegerType(), False), False),
- ... StructField("struct",
+ ... StructField("struct1",
... StructType([StructField("b", ShortType(), False)]), False),
- ... StructField("list", ArrayType(ByteType(), False), False),
- ... StructField("null", DoubleType(), True)])
+ ... StructField("list1", ArrayType(ByteType(), False), False),
+ ... StructField("null1", DoubleType(), True)])
>>> df = sqlCtx.applySchema(rdd, schema)
>>> results = df.map(
- ... lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int, x.float, x.date,
- ... x.time, x.map["a"], x.struct.b, x.list, x.null))
+ ... lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int1, x.float1, x.date1,
+ ... x.time1, x.map1["a"], x.struct1.b, x.list1, x.null1))
>>> results.collect()[0] # doctest: +NORMALIZE_WHITESPACE
(127, -128, -32768, 32767, 2147483647, 1.0, datetime.date(2010, 1, 1),
datetime.datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None)
@@ -288,20 +288,20 @@ class SQLContext(object):
>>> df.registerTempTable("table2")
>>> sqlCtx.sql(
... "SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " +
- ... "short1 + 1 AS short1, short2 - 1 AS short2, int - 1 AS int, " +
- ... "float + 1.5 as float FROM table2").collect()
- [Row(byte1=126, byte2=-127, short1=-32767, short2=32766, int=2147483646, float=2.5)]
+ ... "short1 + 1 AS short1, short2 - 1 AS short2, int1 - 1 AS int1, " +
+ ... "float1 + 1.5 as float1 FROM table2").collect()
+ [Row(byte1=126, byte2=-127, short1=-32767, short2=32766, int1=2147483646, float1=2.5)]
>>> from pyspark.sql.types import _parse_schema_abstract, _infer_schema_type
>>> rdd = sc.parallelize([(127, -32768, 1.0,
... datetime(2010, 1, 1, 1, 1, 1),
... {"a": 1}, (2,), [1, 2, 3])])
- >>> abstract = "byte short float time map{} struct(b) list[]"
+ >>> abstract = "byte1 short1 float1 time1 map1{} struct1(b) list1[]"
>>> schema = _parse_schema_abstract(abstract)
>>> typedSchema = _infer_schema_type(rdd.first(), schema)
>>> df = sqlCtx.applySchema(rdd, typedSchema)
>>> df.collect()
- [Row(byte=127, short=-32768, float=1.0, time=..., list=[1, 2, 3])]
+ [Row(byte1=127, short1=-32768, float1=1.0, time1=..., list1=[1, 2, 3])]
"""
if isinstance(rdd, DataFrame):
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
index 124f083669..b16aff99af 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -78,6 +78,7 @@ class SqlParser extends AbstractSparkSQLParser {
protected val IF = Keyword("IF")
protected val IN = Keyword("IN")
protected val INNER = Keyword("INNER")
+ protected val INT = Keyword("INT")
protected val INSERT = Keyword("INSERT")
protected val INTERSECT = Keyword("INTERSECT")
protected val INTO = Keyword("INTO")
@@ -394,6 +395,7 @@ class SqlParser extends AbstractSparkSQLParser {
| fixedDecimalType
| DECIMAL ^^^ DecimalType.Unlimited
| DATE ^^^ DateType
+ | INT ^^^ IntegerType
)
protected lazy val fixedDecimalType: Parser[DataType] =
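Since the SqlParser change above adds INT to the dataType production, the basic Spark SQL parser (not only HiveQL) now accepts INT in casts. A one-line hypothetical example (jt and a are placeholder names, echoing the InsertSuite change further down):

    // Before this change SqlParser's dataType production had no INT keyword, so this
    // cast did not parse; it now maps to IntegerType.
    sqlContext.sql("SELECT CAST(a AS INT) FROM jt")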
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index fc37b8cde0..e4e542562f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -52,12 +52,6 @@ class Analyzer(catalog: Catalog,
*/
val extendedResolutionRules: Seq[Rule[LogicalPlan]] = Nil
- /**
- * Override to provide additional rules for the "Check Analysis" batch.
- * These rules will be evaluated after our built-in check rules.
- */
- val extendedCheckRules: Seq[Rule[LogicalPlan]] = Nil
-
lazy val batches: Seq[Batch] = Seq(
Batch("Resolution", fixedPoint,
ResolveRelations ::
@@ -71,88 +65,11 @@ class Analyzer(catalog: Catalog,
TrimGroupingAliases ::
typeCoercionRules ++
extendedResolutionRules : _*),
- Batch("Check Analysis", Once,
- CheckResolution +:
- extendedCheckRules: _*),
Batch("Remove SubQueries", fixedPoint,
EliminateSubQueries)
)
/**
- * Makes sure all attributes and logical plans have been resolved.
- */
- object CheckResolution extends Rule[LogicalPlan] {
- def failAnalysis(msg: String) = { throw new AnalysisException(msg) }
-
- def apply(plan: LogicalPlan): LogicalPlan = {
- // We transform up and order the rules so as to catch the first possible failure instead
- // of the result of cascading resolution failures.
- plan.foreachUp {
- case operator: LogicalPlan =>
- operator transformExpressionsUp {
- case a: Attribute if !a.resolved =>
- val from = operator.inputSet.map(_.name).mkString(", ")
- a.failAnalysis(s"cannot resolve '${a.prettyString}' given input columns $from")
-
- case c: Cast if !c.resolved =>
- failAnalysis(
- s"invalid cast from ${c.child.dataType.simpleString} to ${c.dataType.simpleString}")
-
- case b: BinaryExpression if !b.resolved =>
- failAnalysis(
- s"invalid expression ${b.prettyString} " +
- s"between ${b.left.simpleString} and ${b.right.simpleString}")
- }
-
- operator match {
- case f: Filter if f.condition.dataType != BooleanType =>
- failAnalysis(
- s"filter expression '${f.condition.prettyString}' " +
- s"of type ${f.condition.dataType.simpleString} is not a boolean.")
-
- case aggregatePlan @ Aggregate(groupingExprs, aggregateExprs, child) =>
- def checkValidAggregateExpression(expr: Expression): Unit = expr match {
- case _: AggregateExpression => // OK
- case e: Attribute if !groupingExprs.contains(e) =>
- failAnalysis(
- s"expression '${e.prettyString}' is neither present in the group by, " +
- s"nor is it an aggregate function. " +
- "Add to group by or wrap in first() if you don't care which value you get.")
- case e if groupingExprs.contains(e) => // OK
- case e if e.references.isEmpty => // OK
- case e => e.children.foreach(checkValidAggregateExpression)
- }
-
- val cleaned = aggregateExprs.map(_.transform {
- // Should trim aliases around `GetField`s. These aliases are introduced while
- // resolving struct field accesses, because `GetField` is not a `NamedExpression`.
- // (Should we just turn `GetField` into a `NamedExpression`?)
- case Alias(g, _) => g
- })
-
- cleaned.foreach(checkValidAggregateExpression)
-
- case o if o.children.nonEmpty &&
- !o.references.filter(_.name != "grouping__id").subsetOf(o.inputSet) =>
- val missingAttributes = (o.references -- o.inputSet).map(_.prettyString).mkString(",")
- val input = o.inputSet.map(_.prettyString).mkString(",")
-
- failAnalysis(s"resolved attributes $missingAttributes missing from $input")
-
- // Catch all
- case o if !o.resolved =>
- failAnalysis(
- s"unresolved operator ${operator.simpleString}")
-
- case _ => // Analysis successful!
- }
- }
-
- plan
- }
- }
-
- /**
* Removes no-op Alias expressions from the plan.
*/
object TrimGroupingAliases extends Rule[LogicalPlan] {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
new file mode 100644
index 0000000000..4e8fc892f3
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.types._
+
+/**
+ * Throws user facing errors when passed invalid queries that fail to analyze.
+ */
+class CheckAnalysis {
+
+ /**
+ * Override to provide additional checks for correct analysis.
+ * These rules will be evaluated after our built-in check rules.
+ */
+ val extendedCheckRules: Seq[LogicalPlan => Unit] = Nil
+
+ def failAnalysis(msg: String) = {
+ throw new AnalysisException(msg)
+ }
+
+ def apply(plan: LogicalPlan): Unit = {
+ // We transform up and order the rules so as to catch the first possible failure instead
+ // of the result of cascading resolution failures.
+ plan.foreachUp {
+ case operator: LogicalPlan =>
+ operator transformExpressionsUp {
+ case a: Attribute if !a.resolved =>
+ val from = operator.inputSet.map(_.name).mkString(", ")
+ a.failAnalysis(s"cannot resolve '${a.prettyString}' given input columns $from")
+
+ case c: Cast if !c.resolved =>
+ failAnalysis(
+ s"invalid cast from ${c.child.dataType.simpleString} to ${c.dataType.simpleString}")
+
+ case b: BinaryExpression if !b.resolved =>
+ failAnalysis(
+ s"invalid expression ${b.prettyString} " +
+ s"between ${b.left.simpleString} and ${b.right.simpleString}")
+ }
+
+ operator match {
+ case f: Filter if f.condition.dataType != BooleanType =>
+ failAnalysis(
+ s"filter expression '${f.condition.prettyString}' " +
+ s"of type ${f.condition.dataType.simpleString} is not a boolean.")
+
+ case aggregatePlan@Aggregate(groupingExprs, aggregateExprs, child) =>
+ def checkValidAggregateExpression(expr: Expression): Unit = expr match {
+ case _: AggregateExpression => // OK
+ case e: Attribute if !groupingExprs.contains(e) =>
+ failAnalysis(
+ s"expression '${e.prettyString}' is neither present in the group by, " +
+ s"nor is it an aggregate function. " +
+ "Add to group by or wrap in first() if you don't care which value you get.")
+ case e if groupingExprs.contains(e) => // OK
+ case e if e.references.isEmpty => // OK
+ case e => e.children.foreach(checkValidAggregateExpression)
+ }
+
+ val cleaned = aggregateExprs.map(_.transform {
+ // Should trim aliases around `GetField`s. These aliases are introduced while
+ // resolving struct field accesses, because `GetField` is not a `NamedExpression`.
+ // (Should we just turn `GetField` into a `NamedExpression`?)
+ case Alias(g, _) => g
+ })
+
+ cleaned.foreach(checkValidAggregateExpression)
+
+ case o if o.children.nonEmpty &&
+ !o.references.filter(_.name != "grouping__id").subsetOf(o.inputSet) =>
+ val missingAttributes = (o.references -- o.inputSet).map(_.prettyString).mkString(",")
+ val input = o.inputSet.map(_.prettyString).mkString(",")
+
+ failAnalysis(s"resolved attributes $missingAttributes missing from $input")
+
+ // Catch all
+ case o if !o.resolved =>
+ failAnalysis(
+ s"unresolved operator ${operator.simpleString}")
+
+ case _ => // Analysis successful!
+ }
+ }
+ extendedCheckRules.foreach(_(plan))
+ }
+}
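The extendedCheckRules hook that moved here from Analyzer now takes plain LogicalPlan => Unit functions, run after the built-in checks (the SQLContext change below wires in sources.PreWriteCheck this way). A purely hypothetical extra check, just to illustrate the extension point, could be plugged in like this:

    import org.apache.spark.sql.AnalysisException
    import org.apache.spark.sql.catalyst.analysis.CheckAnalysis
    import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Sort}

    // Invented policy, for illustration only: reject any plan that contains a Sort.
    val strictCheckAnalysis = new CheckAnalysis {
      override val extendedCheckRules: Seq[LogicalPlan => Unit] = Seq(
        (plan: LogicalPlan) => plan.foreach {
          case _: Sort =>
            throw new AnalysisException("ORDER BY is not allowed by this made-up policy")
          case _ => // OK
        }
      )
    }

    // strictCheckAnalysis(analyzedPlan) runs the built-in checks first, then this rule.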
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index aec7847356..c1dd5aa913 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -30,11 +30,21 @@ import org.apache.spark.sql.catalyst.dsl.plans._
class AnalysisSuite extends FunSuite with BeforeAndAfter {
val caseSensitiveCatalog = new SimpleCatalog(true)
val caseInsensitiveCatalog = new SimpleCatalog(false)
- val caseSensitiveAnalyze =
+
+ val caseSensitiveAnalyzer =
new Analyzer(caseSensitiveCatalog, EmptyFunctionRegistry, caseSensitive = true)
- val caseInsensitiveAnalyze =
+ val caseInsensitiveAnalyzer =
new Analyzer(caseInsensitiveCatalog, EmptyFunctionRegistry, caseSensitive = false)
+ val checkAnalysis = new CheckAnalysis
+
+
+ def caseSensitiveAnalyze(plan: LogicalPlan) =
+ checkAnalysis(caseSensitiveAnalyzer(plan))
+
+ def caseInsensitiveAnalyze(plan: LogicalPlan) =
+ checkAnalysis(caseInsensitiveAnalyzer(plan))
+
val testRelation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)())
val testRelation2 = LocalRelation(
AttributeReference("a", StringType)(),
@@ -55,7 +65,7 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
a.select(UnresolvedStar(None)).select('a).unionAll(b.select(UnresolvedStar(None)))
}
- assert(caseInsensitiveAnalyze(plan).resolved)
+ assert(caseInsensitiveAnalyzer(plan).resolved)
}
test("check project's resolved") {
@@ -71,11 +81,11 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
test("analyze project") {
assert(
- caseSensitiveAnalyze(Project(Seq(UnresolvedAttribute("a")), testRelation)) ===
+ caseSensitiveAnalyzer(Project(Seq(UnresolvedAttribute("a")), testRelation)) ===
Project(testRelation.output, testRelation))
assert(
- caseSensitiveAnalyze(
+ caseSensitiveAnalyzer(
Project(Seq(UnresolvedAttribute("TbL.a")),
UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) ===
Project(testRelation.output, testRelation))
@@ -88,13 +98,13 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
assert(e.getMessage().toLowerCase.contains("cannot resolve"))
assert(
- caseInsensitiveAnalyze(
+ caseInsensitiveAnalyzer(
Project(Seq(UnresolvedAttribute("TbL.a")),
UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) ===
Project(testRelation.output, testRelation))
assert(
- caseInsensitiveAnalyze(
+ caseInsensitiveAnalyzer(
Project(Seq(UnresolvedAttribute("tBl.a")),
UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) ===
Project(testRelation.output, testRelation))
@@ -107,16 +117,13 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
assert(e.getMessage == "Table Not Found: tAbLe")
assert(
- caseSensitiveAnalyze(UnresolvedRelation(Seq("TaBlE"), None)) ===
- testRelation)
+ caseSensitiveAnalyzer(UnresolvedRelation(Seq("TaBlE"), None)) === testRelation)
assert(
- caseInsensitiveAnalyze(UnresolvedRelation(Seq("tAbLe"), None)) ===
- testRelation)
+ caseInsensitiveAnalyzer(UnresolvedRelation(Seq("tAbLe"), None)) === testRelation)
assert(
- caseInsensitiveAnalyze(UnresolvedRelation(Seq("TaBlE"), None)) ===
- testRelation)
+ caseInsensitiveAnalyzer(UnresolvedRelation(Seq("TaBlE"), None)) === testRelation)
}
def errorTest(
@@ -177,7 +184,7 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
AttributeReference("d", DecimalType.Unlimited)(),
AttributeReference("e", ShortType)())
- val plan = caseInsensitiveAnalyze(
+ val plan = caseInsensitiveAnalyzer(
testRelation2.select(
'a / Literal(2) as 'div1,
'a / 'b as 'div2,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 69e5f6a07d..27ac398063 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -117,7 +117,7 @@ class DataFrame protected[sql](
this(sqlContext, {
val qe = sqlContext.executePlan(logicalPlan)
if (sqlContext.conf.dataFrameEagerAnalysis) {
- qe.analyzed // This should force analysis and throw errors if there are any
+ qe.assertAnalyzed() // This should force analysis and throw errors if there are any
}
qe
})
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index 39f6c2f4bc..a08c0f5ce3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -52,8 +52,9 @@ private[spark] object SQLConf {
// This is used to set the default data source
val DEFAULT_DATA_SOURCE_NAME = "spark.sql.sources.default"
- // Whether to perform eager analysis on a DataFrame.
- val DATAFRAME_EAGER_ANALYSIS = "spark.sql.dataframe.eagerAnalysis"
+ // Whether to perform eager analysis when constructing a dataframe.
+ // Set to false when debugging requires the ability to look at invalid query plans.
+ val DATAFRAME_EAGER_ANALYSIS = "spark.sql.eagerAnalysis"
object Deprecated {
val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 4bdaa02391..ce800e0754 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -114,7 +114,6 @@ class SQLContext(@transient val sparkContext: SparkContext)
new Analyzer(catalog, functionRegistry, caseSensitive = true) {
override val extendedResolutionRules =
ExtractPythonUdfs ::
- sources.PreWriteCheck(catalog) ::
sources.PreInsertCastAndRename ::
Nil
}
@@ -1057,6 +1056,13 @@ class SQLContext(@transient val sparkContext: SparkContext)
Batch("Add exchange", Once, AddExchange(self)) :: Nil
}
+ @transient
+ protected[sql] lazy val checkAnalysis = new CheckAnalysis {
+ override val extendedCheckRules = Seq(
+ sources.PreWriteCheck(catalog)
+ )
+ }
+
/**
* :: DeveloperApi ::
* The primary workflow for executing relational queries using Spark. Designed to allow easy
@@ -1064,9 +1070,13 @@ class SQLContext(@transient val sparkContext: SparkContext)
*/
@DeveloperApi
protected[sql] class QueryExecution(val logical: LogicalPlan) {
+ def assertAnalyzed(): Unit = checkAnalysis(analyzed)
lazy val analyzed: LogicalPlan = analyzer(logical)
- lazy val withCachedData: LogicalPlan = cacheManager.useCachedData(analyzed)
+ lazy val withCachedData: LogicalPlan = {
+ assertAnalyzed
+ cacheManager.useCachedData(analyzed)
+ }
lazy val optimizedPlan: LogicalPlan = optimizer(withCachedData)
// TODO: Don't just pick the first one...
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala
index 36a9c0bdc4..8440581074 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala
@@ -78,10 +78,10 @@ private[sql] object PreInsertCastAndRename extends Rule[LogicalPlan] {
/**
* A rule to do various checks before inserting into or writing to a data source table.
*/
-private[sql] case class PreWriteCheck(catalog: Catalog) extends Rule[LogicalPlan] {
+private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan => Unit) {
def failAnalysis(msg: String) = { throw new AnalysisException(msg) }
- def apply(plan: LogicalPlan): LogicalPlan = {
+ def apply(plan: LogicalPlan): Unit = {
plan.foreach {
case i @ logical.InsertIntoTable(
l @ LogicalRelation(t: InsertableRelation), partition, query, overwrite) =>
@@ -93,7 +93,7 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends Rule[LogicalPlan
val srcRelations = query.collect {
case LogicalRelation(src: BaseRelation) => src
}
- if (srcRelations.exists(src => src == t)) {
+ if (srcRelations.contains(t)) {
failAnalysis(
"Cannot insert overwrite into table that is also being read from.")
} else {
@@ -119,7 +119,7 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends Rule[LogicalPlan
val srcRelations = query.collect {
case LogicalRelation(src: BaseRelation) => src
}
- if (srcRelations.exists(src => src == dest)) {
+ if (srcRelations.contains(dest)) {
failAnalysis(
s"Cannot overwrite table $tableName that is also being read from.")
} else {
@@ -134,7 +134,5 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends Rule[LogicalPlan
case _ => // OK
}
-
- plan
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala
index 0ec6881d7a..91c6367371 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala
@@ -30,7 +30,6 @@ abstract class DataSourceTest extends QueryTest with BeforeAndAfter {
override protected[sql] lazy val analyzer: Analyzer =
new Analyzer(catalog, functionRegistry, caseSensitive = false) {
override val extendedResolutionRules =
- PreWriteCheck(catalog) ::
PreInsertCastAndRename ::
Nil
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
index 5682e5a2bc..b5b16f9546 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
@@ -205,7 +205,7 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll {
val message = intercept[AnalysisException] {
sql(
s"""
- |INSERT OVERWRITE TABLE oneToTen SELECT a FROM jt
+ |INSERT OVERWRITE TABLE oneToTen SELECT CAST(a AS INT) FROM jt
""".stripMargin)
}.getMessage
assert(
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index 2e205e67c0..c439dfe0a7 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -268,7 +268,6 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
catalog.PreInsertionCasts ::
ExtractPythonUdfs ::
ResolveUdtfsAlias ::
- sources.PreWriteCheck(catalog) ::
sources.PreInsertCastAndRename ::
Nil
}