author    Michael Armbrust <michael@databricks.com>    2015-02-16 12:32:56 -0800
committer Michael Armbrust <michael@databricks.com>    2015-02-16 12:32:56 -0800
commit    104b2c45805ce0a9c86e2823f402de6e9f0aee81 (patch)
tree      af1fbfad7997e5560c730f5ff3058d91cd085c5a /sql
parent    275a0c08134dea1896eab73a8e017256900fb1db (diff)
[SQL] Initial support for reporting location of error in sql string
Author: Michael Armbrust <michael@databricks.com>

Closes #4587 from marmbrus/position and squashes the following commits:

0810052 [Michael Armbrust] fix tests
395c019 [Michael Armbrust] Merge remote-tracking branch 'marmbrus/position' into position
e155dce [Michael Armbrust] more errors
f3efa51 [Michael Armbrust] Update AnalysisException.scala
d45ff60 [Michael Armbrust] [SQL] Initial support for reporting location of error in sql string
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala | 17
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 19
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala | 6
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala | 10
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala | 60
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala | 14
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala | 9
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala | 47
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala | 4
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala | 163
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala | 4
11 files changed, 314 insertions, 39 deletions
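
The user-visible effect, as a hedged sketch (not part of the patch; the context name sqlContext, the table, and the column are illustrative): analyzing an invalid query now raises an AnalysisException carrying the offending token's line and column.

    // Illustrative only: assumes a HiveContext bound to `sqlContext`.
    import org.apache.spark.sql.AnalysisException

    try {
      sqlContext.sql("SELECT badColumn FROM src").collect()
    } catch {
      case e: AnalysisException =>
        // e.g. "cannot resolve 'badColumn' given input columns key, value; line 1 pos 7"
        println(s"line=${e.line} pos=${e.startPosition}: ${e.getMessage}")
    }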
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala
index 871d560b9d..15add84878 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala
@@ -17,7 +17,22 @@
package org.apache.spark.sql
+import org.apache.spark.annotation.DeveloperApi
+
/**
+ * :: DeveloperApi ::
* Thrown when a query fails to analyze, usually because the query itself is invalid.
*/
-class AnalysisException(message: String) extends Exception(message) with Serializable
+@DeveloperApi
+class AnalysisException protected[sql] (
+ val message: String,
+ val line: Option[Int] = None,
+ val startPosition: Option[Int] = None)
+ extends Exception with Serializable {
+
+ override def getMessage: String = {
+ val lineAnnotation = line.map(l => s" line $l").getOrElse("")
+ val positionAnnotation = startPosition.map(p => s" pos $p").getOrElse("")
+ s"$message;$lineAnnotation$positionAnnotation"
+ }
+}
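
In isolation, the new getMessage logic behaves as below; a standalone sketch in which the helper name annotate is invented for illustration:

    // Mirrors the Option handling in AnalysisException.getMessage above.
    def annotate(message: String, line: Option[Int], startPosition: Option[Int]): String = {
      val lineAnnotation = line.map(l => s" line $l").getOrElse("")
      val positionAnnotation = startPosition.map(p => s" pos $p").getOrElse("")
      s"$message;$lineAnnotation$positionAnnotation"
    }

    annotate("no such table src", Some(3), Some(14))  // "no such table src; line 3 pos 14"
    annotate("no such table src", None, None)         // "no such table src;"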
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 58a7003977..aa4320bd58 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -85,7 +85,7 @@ class Analyzer(catalog: Catalog,
operator transformExpressionsUp {
case a: Attribute if !a.resolved =>
val from = operator.inputSet.map(_.name).mkString(", ")
- failAnalysis(s"cannot resolve '${a.prettyString}' given input columns $from")
+ a.failAnalysis(s"cannot resolve '${a.prettyString}' given input columns $from")
case c: Cast if !c.resolved =>
failAnalysis(
@@ -246,12 +246,21 @@ class Analyzer(catalog: Catalog,
* Replaces [[UnresolvedRelation]]s with concrete relations from the catalog.
*/
object ResolveRelations extends Rule[LogicalPlan] {
+ def getTable(u: UnresolvedRelation) = {
+ try {
+ catalog.lookupRelation(u.tableIdentifier, u.alias)
+ } catch {
+ case _: NoSuchTableException =>
+ u.failAnalysis(s"no such table ${u.tableIdentifier}")
+ }
+ }
+
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case i @ InsertIntoTable(UnresolvedRelation(tableIdentifier, alias), _, _, _) =>
+ case i @ InsertIntoTable(u: UnresolvedRelation, _, _, _) =>
i.copy(
- table = EliminateSubQueries(catalog.lookupRelation(tableIdentifier, alias)))
- case UnresolvedRelation(tableIdentifier, alias) =>
- catalog.lookupRelation(tableIdentifier, alias)
+ table = EliminateSubQueries(getTable(u)))
+ case u: UnresolvedRelation =>
+ getTable(u)
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala
index f57eab2460..bf97215ee6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala
@@ -22,6 +22,12 @@ import scala.collection.mutable
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Subquery}
/**
+ * Thrown by a catalog when a table cannot be found. The analyzer will rethrow the exception
+ * as an AnalysisException with the correct position information.
+ */
+class NoSuchTableException extends Exception
+
+/**
* An interface for looking up relations by name. Used by an [[Analyzer]].
*/
trait Catalog {
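
A toy sketch of the division of labor the new comment describes (the map-backed catalog is invented for illustration): a catalog knows nothing about source positions, so it throws a bare NoSuchTableException and leaves position handling to the analyzer's ResolveRelations rule above.

    import org.apache.spark.sql.catalyst.analysis.NoSuchTableException
    import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

    // Positionless lookup; the analyzer attaches position info on rethrow.
    class MapBackedCatalog(tables: Map[String, LogicalPlan]) {
      def lookupRelation(name: String): LogicalPlan =
        tables.getOrElse(name, throw new NoSuchTableException)
    }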
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala
index 5dc9d0e566..e95f19e69e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala
@@ -17,6 +17,9 @@
package org.apache.spark.sql.catalyst
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.trees.TreeNode
+
/**
* Provides a logical query plan [[Analyzer]] and supporting classes for performing analysis.
* Analysis consists of translating [[UnresolvedAttribute]]s and [[UnresolvedRelation]]s
@@ -32,4 +35,11 @@ package object analysis {
val caseInsensitiveResolution = (a: String, b: String) => a.equalsIgnoreCase(b)
val caseSensitiveResolution = (a: String, b: String) => a == b
+
+ implicit class AnalysisErrorAt(t: TreeNode[_]) {
+ /** Fails the analysis at the point where a specific tree node was parsed. */
+ def failAnalysis(msg: String) = {
+ throw new AnalysisException(msg, t.origin.line, t.origin.startPosition)
+ }
+ }
}
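
End to end, the implicit works as sketched below; Literal stands in for a node produced during parsing, and the position values are illustrative:

    import org.apache.spark.sql.catalyst.analysis._            // AnalysisErrorAt implicit
    import org.apache.spark.sql.catalyst.expressions.Literal
    import org.apache.spark.sql.catalyst.trees.CurrentOrigin

    CurrentOrigin.setPosition(2, 7)   // as the parser does per token
    val node = Literal(1)             // captures Origin(Some(2), Some(7))
    CurrentOrigin.reset()

    node.failAnalysis("no such table src")
    // throws: org.apache.spark.sql.AnalysisException: no such table src; line 2 pos 7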
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index e0930b056d..109671bdca 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -22,9 +22,42 @@ import org.apache.spark.sql.catalyst.errors._
/** Used by [[TreeNode.getNodeNumbered]] when traversing the tree for a given number */
private class MutableInt(var i: Int)
+case class Origin(
+ line: Option[Int] = None,
+ startPosition: Option[Int] = None)
+
+/**
+ * Provides a location for TreeNodes to ask about the context of their origin. For example, which
+ * line of code is currently being parsed.
+ */
+object CurrentOrigin {
+ private val value = new ThreadLocal[Origin]() {
+ override def initialValue: Origin = Origin()
+ }
+
+ def get = value.get()
+ def set(o: Origin) = value.set(o)
+
+ def reset() = value.set(Origin())
+
+ def setPosition(line: Int, start: Int) = {
+ value.set(
+ value.get.copy(line = Some(line), startPosition = Some(start)))
+ }
+
+ def withOrigin[A](o: Origin)(f: => A): A = {
+ set(o)
+ val ret = try f finally { reset() }
+ reset()
+ ret
+ }
+}
+
abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
self: BaseType with Product =>
+ val origin = CurrentOrigin.get
+
/** Returns a Seq of the children of this node */
def children: Seq[BaseType]
@@ -150,7 +183,10 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
* @param rule the function used to transform this nodes children
*/
def transformDown(rule: PartialFunction[BaseType, BaseType]): BaseType = {
- val afterRule = rule.applyOrElse(this, identity[BaseType])
+ val afterRule = CurrentOrigin.withOrigin(origin) {
+ rule.applyOrElse(this, identity[BaseType])
+ }
+
// Check if unchanged and then possibly return old copy to avoid gc churn.
if (this fastEquals afterRule) {
transformChildrenDown(rule)
@@ -210,9 +246,13 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
def transformUp(rule: PartialFunction[BaseType, BaseType]): BaseType = {
val afterRuleOnChildren = transformChildrenUp(rule);
if (this fastEquals afterRuleOnChildren) {
- rule.applyOrElse(this, identity[BaseType])
+ CurrentOrigin.withOrigin(origin) {
+ rule.applyOrElse(this, identity[BaseType])
+ }
} else {
- rule.applyOrElse(afterRuleOnChildren, identity[BaseType])
+ CurrentOrigin.withOrigin(origin) {
+ rule.applyOrElse(afterRuleOnChildren, identity[BaseType])
+ }
}
}
@@ -268,12 +308,14 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
*/
def makeCopy(newArgs: Array[AnyRef]): this.type = attachTree(this, "makeCopy") {
try {
- // Skip no-arg constructors that are just there for kryo.
- val defaultCtor = getClass.getConstructors.find(_.getParameterTypes.size != 0).head
- if (otherCopyArgs.isEmpty) {
- defaultCtor.newInstance(newArgs: _*).asInstanceOf[this.type]
- } else {
- defaultCtor.newInstance((newArgs ++ otherCopyArgs).toArray: _*).asInstanceOf[this.type]
+ CurrentOrigin.withOrigin(origin) {
+ // Skip no-arg constructors that are just there for kryo.
+ val defaultCtor = getClass.getConstructors.find(_.getParameterTypes.size != 0).head
+ if (otherCopyArgs.isEmpty) {
+ defaultCtor.newInstance(newArgs: _*).asInstanceOf[this.type]
+ } else {
+ defaultCtor.newInstance((newArgs ++ otherCopyArgs).toArray: _*).asInstanceOf[this.type]
+ }
}
} catch {
case e: java.lang.IllegalArgumentException =>
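
The thread-local contract of CurrentOrigin.withOrigin, exercised standalone (a sketch, not patch code):

    import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin}

    val o = Origin(line = Some(4), startPosition = Some(10))
    val seen = CurrentOrigin.withOrigin(o) {
      // Nodes built inside this block (for example by a transform rule)
      // record `o` as their origin, which is how transformDown/transformUp
      // keep positions stable across rewrites.
      CurrentOrigin.get
    }
    assert(seen == o)
    assert(CurrentOrigin.get == Origin())  // withOrigin resets the thread-local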
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
index cdb843f959..e7ce92a216 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
@@ -104,4 +104,18 @@ class TreeNodeSuite extends FunSuite {
assert(actual === Dummy(None))
}
+ test("preserves origin") {
+ CurrentOrigin.setPosition(1,1)
+ val add = Add(Literal(1), Literal(1))
+ CurrentOrigin.reset()
+
+ val transformed = add transform {
+ case Literal(1, _) => Literal(2)
+ }
+
+ assert(transformed.origin.line.isDefined)
+ assert(transformed.origin.startPosition.isDefined)
+ }
+
+
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index f82778c876..12f86a04a3 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -31,8 +31,8 @@ import org.apache.hadoop.hive.serde2.{Deserializer, SerDeException}
import org.apache.hadoop.util.ReflectionUtils
import org.apache.spark.Logging
-import org.apache.spark.sql.SQLContext
-import org.apache.spark.sql.catalyst.analysis.{Catalog, OverrideCatalog}
+import org.apache.spark.sql.{AnalysisException, SQLContext}
+import org.apache.spark.sql.catalyst.analysis.{NoSuchTableException, Catalog, OverrideCatalog}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical
@@ -154,7 +154,10 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with
val databaseName = tableIdent.lift(tableIdent.size - 2).getOrElse(
hive.sessionState.getCurrentDatabase)
val tblName = tableIdent.last
- val table = client.getTable(databaseName, tblName)
+ val table = try client.getTable(databaseName, tblName) catch {
+ case te: org.apache.hadoop.hive.ql.metadata.InvalidTableException =>
+ throw new NoSuchTableException
+ }
if (table.getProperty("spark.sql.sources.provider") != null) {
cachedDataSourceTables(QualifiedTableName(databaseName, tblName).toLowerCase)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index 5269460e5b..5a1825a87d 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.hive
import java.sql.Date
+
import scala.collection.mutable.ArrayBuffer
import org.apache.hadoop.hive.conf.HiveConf
@@ -27,13 +28,14 @@ import org.apache.hadoop.hive.ql.lib.Node
import org.apache.hadoop.hive.ql.metadata.Table
import org.apache.hadoop.hive.ql.parse._
import org.apache.hadoop.hive.ql.plan.PlanUtils
-import org.apache.spark.sql.SparkSQLParser
+import org.apache.spark.sql.{AnalysisException, SparkSQLParser}
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.trees.CurrentOrigin
import org.apache.spark.sql.execution.ExplainCommand
import org.apache.spark.sql.sources.DescribeCommand
import org.apache.spark.sql.hive.execution.{HiveNativeCommand, DropTable, AnalyzeTable, HiveScriptIOSchema}
@@ -211,12 +213,6 @@ private[hive] object HiveQl {
}
}
- class ParseException(sql: String, cause: Throwable)
- extends Exception(s"Failed to parse: $sql", cause)
-
- class SemanticException(msg: String)
- extends Exception(s"Error in semantic analysis: $msg")
-
/**
* Returns the AST for the given SQL string.
*/
@@ -236,8 +232,10 @@ private[hive] object HiveQl {
/** Returns a LogicalPlan for a given HiveQL string. */
def parseSql(sql: String): LogicalPlan = hqlParser(sql)
+ val errorRegEx = "line (\\d+):(\\d+) (.*)".r
+
/** Creates LogicalPlan for a given HiveQL string. */
- def createPlan(sql: String) = {
+ def createPlan(sql: String): LogicalPlan = {
try {
val tree = getAst(sql)
if (nativeCommands contains tree.getText) {
@@ -249,14 +247,23 @@ private[hive] object HiveQl {
}
}
} catch {
- case e: Exception => throw new ParseException(sql, e)
- case e: NotImplementedError => sys.error(
- s"""
- |Unsupported language features in query: $sql
- |${dumpTree(getAst(sql))}
- |$e
- |${e.getStackTrace.head}
- """.stripMargin)
+ case pe: org.apache.hadoop.hive.ql.parse.ParseException =>
+ pe.getMessage match {
+ case errorRegEx(line, start, message) =>
+ throw new AnalysisException(message, Some(line.toInt), Some(start.toInt))
+ case otherMessage =>
+ throw new AnalysisException(otherMessage)
+ }
+ case e: Exception =>
+ throw new AnalysisException(e.getMessage)
+ case e: NotImplementedError =>
+ throw new AnalysisException(
+ s"""
+ |Unsupported language features in query: $sql
+ |${dumpTree(getAst(sql))}
+ |$e
+ |${e.getStackTrace.head}
+ """.stripMargin)
}
}
@@ -292,6 +299,7 @@ private[hive] object HiveQl {
/** @return matches of the form (tokenName, children). */
def unapply(t: Any): Option[(String, Seq[ASTNode])] = t match {
case t: ASTNode =>
+ CurrentOrigin.setPosition(t.getLine, t.getCharPositionInLine)
Some((t.getText,
Option(t.getChildren).map(_.toList).getOrElse(Nil).asInstanceOf[Seq[ASTNode]]))
case _ => None
@@ -1278,7 +1286,12 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
def dumpTree(node: Node, builder: StringBuilder = new StringBuilder, indent: Int = 0)
: StringBuilder = {
node match {
- case a: ASTNode => builder.append((" " * indent) + a.getText + "\n")
+ case a: ASTNode => builder.append(
+ (" " * indent) + a.getText + " " +
+ a.getLine + ", " +
+ a.getTokenStartIndex + "," +
+ a.getTokenStopIndex + ", " +
+ a.getCharPositionInLine + "\n")
case other => sys.error(s"Non ASTNode encountered: $other")
}
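
The errorRegEx translation in createPlan, as a standalone sketch; the message below only illustrates the "line L:C ..." shape of Hive parse errors:

    val errorRegEx = "line (\\d+):(\\d+) (.*)".r

    "line 1:7 cannot recognize input near 'WHERE'" match {
      case errorRegEx(line, start, message) =>
        // Positioned: becomes AnalysisException(message, Some(line), Some(start))
        println(s"line=${line.toInt} pos=${start.toInt} msg=$message")
      case other =>
        // Unpositioned: becomes AnalysisException(other)
        println(s"no position available: $other")
    }
    // prints: line=1 pos=7 msg=cannot recognize input near 'WHERE'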
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala
index 7c8b5205e2..44d24273e7 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.hive
import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation}
import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.hive.test.TestHive._
-import org.apache.spark.sql.{DataFrame, QueryTest}
+import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest}
import org.apache.spark.storage.RDDBlockId
class CachedTableSuite extends QueryTest {
@@ -96,7 +96,7 @@ class CachedTableSuite extends QueryTest {
cacheTable("test")
sql("SELECT * FROM test").collect()
sql("DROP TABLE test")
- intercept[org.apache.hadoop.hive.ql.metadata.InvalidTableException] {
+ intercept[AnalysisException] {
sql("SELECT * FROM test").collect()
}
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala
new file mode 100644
index 0000000000..f04437c595
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala
@@ -0,0 +1,163 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive
+
+import java.io.{OutputStream, PrintStream}
+
+import org.apache.spark.sql.hive.test.TestHive._
+import org.apache.spark.sql.{AnalysisException, QueryTest}
+
+import scala.util.Try
+
+class ErrorPositionSuite extends QueryTest {
+
+ positionTest("unresolved attribute 1",
+ "SELECT x FROM src", "x")
+
+ positionTest("unresolved attribute 2",
+ "SELECT x FROM src", "x")
+
+ positionTest("unresolved attribute 3",
+ "SELECT key, x FROM src", "x")
+
+ positionTest("unresolved attribute 4",
+ """SELECT key,
+ |x FROM src
+ """.stripMargin, "x")
+
+ positionTest("unresolved attribute 5",
+ """SELECT key,
+ | x FROM src
+ """.stripMargin, "x")
+
+ positionTest("unresolved attribute 6",
+ """SELECT key,
+ |
+ | 1 + x FROM src
+ """.stripMargin, "x")
+
+ positionTest("unresolved attribute 7",
+ """SELECT key,
+ |
+ | 1 + x + 1 FROM src
+ """.stripMargin, "x")
+
+ positionTest("multi-char unresolved attribute",
+ """SELECT key,
+ |
+ | 1 + abcd + 1 FROM src
+ """.stripMargin, "abcd")
+
+ positionTest("unresolved attribute group by",
+ """SELECT key FROM src GROUP BY
+ |x
+ """.stripMargin, "x")
+
+ positionTest("unresolved attribute order by",
+ """SELECT key FROM src ORDER BY
+ |x
+ """.stripMargin, "x")
+
+ positionTest("unresolved attribute where",
+ """SELECT key FROM src
+ |WHERE x = true
+ """.stripMargin, "x")
+
+ positionTest("unresolved attribute backticks",
+ "SELECT `x` FROM src", "`x`")
+
+ positionTest("parse error",
+ "SELECT WHERE", "WHERE")
+
+ positionTest("bad relation",
+ "SELECT * FROM badTable", "badTable")
+
+ ignore("other expressions") {
+ positionTest("bad addition",
+ "SELECT 1 + array(1)", "1 + array")
+ }
+
+ /** Hive can be very noisy, messing up the output of our tests. */
+ private def quietly[A](f: => A): A = {
+ val origErr = System.err
+ val origOut = System.out
+ try {
+ System.setErr(new PrintStream(new OutputStream {
+ def write(b: Int) = {}
+ }))
+ System.setOut(new PrintStream(new OutputStream {
+ def write(b: Int) = {}
+ }))
+
+ f
+ } finally {
+ System.setErr(origErr)
+ System.setOut(origOut)
+ }
+ }
+
+ /**
+ * Creates a test that checks to see if the error thrown when analyzing a given query includes
+ * the location of the given token in the query string.
+ *
+ * @param name the name of the test
+ * @param query the query to analyze
+ * @param token a unique token in the string that should be indicated by the exception
+ */
+ def positionTest(name: String, query: String, token: String) = {
+ def parseTree =
+ Try(quietly(HiveQl.dumpTree(HiveQl.getAst(query)))).getOrElse("<failed to parse>")
+
+ test(name) {
+ val error = intercept[AnalysisException] {
+ quietly(sql(query))
+ }
+ val (line, expectedLineNum) = query.split("\n").zipWithIndex.collect {
+ case (l, i) if l.contains(token) => (l, i + 1)
+ }.headOption.getOrElse(sys.error(s"Invalid test. Token $token not in $query"))
+ val actualLine = error.line.getOrElse {
+ fail(
+ s"line not returned for error '${error.getMessage}' on token $token\n$parseTree"
+ )
+ }
+ assert(actualLine === expectedLineNum, "wrong line")
+
+ val expectedStart = line.indexOf(token)
+ val actualStart = error.startPosition.getOrElse {
+ fail(
+ s"start not returned for error on token $token\n" +
+ HiveQl.dumpTree(HiveQl.getAst(query))
+ )
+ }
+ assert(expectedStart === actualStart,
+ s"""Incorrect start position.
+ |== QUERY ==
+ |$query
+ |
+ |== AST ==
+ |$parseTree
+ |
+ |Actual: $actualStart, Expected: $expectedStart
+ |$line
+ |${" " * actualStart}^
+ |0123456789 123456789 1234567890
+ | 2 3
+ """.stripMargin)
+ }
+ }
+}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index 9788259383..e8d9eec3d8 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -22,7 +22,7 @@ import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.hive.test.TestHive._
import org.apache.spark.sql.hive.test.TestHive.implicits._
import org.apache.spark.sql.types._
-import org.apache.spark.sql.{QueryTest, Row, SQLConf}
+import org.apache.spark.sql.{AnalysisException, QueryTest, Row, SQLConf}
case class Nested1(f1: Nested2)
case class Nested2(f2: Nested3)
@@ -185,7 +185,7 @@ class SQLQuerySuite extends QueryTest {
sql("SELECT * FROM test_ctas_1234"),
sql("SELECT * FROM nested").collect().toSeq)
- intercept[org.apache.hadoop.hive.ql.metadata.InvalidTableException] {
+ intercept[AnalysisException] {
sql("CREATE TABLE test_ctas_12345 AS SELECT * from notexists").collect()
}
}