author    Wenchen Fan <cloud0fan@outlook.com>        2015-04-15 13:39:12 -0700
committer Michael Armbrust <michael@databricks.com>  2015-04-15 13:39:12 -0700
commit    4754e16f4746ebd882b2ce7f1efc6e4d4408922c (patch)
tree      873f70e80640277683a56b7ef09c7561ac336361 /sql
parent    557a797a273f1668065806cba53e19e6134a66d3 (diff)
[SPARK-6898][SQL] completely support special chars in column names
Even if we wrap a column name in backticks, like `` `a#$b.c` ``, we still treat the "." inside the name specially. I think it's fragile to rely on a special character to split name parts, so why not store the name parts in `UnresolvedAttribute` directly?

Author: Wenchen Fan <cloud0fan@outlook.com>

This patch had conflicts when merged, resolved by
Committer: Michael Armbrust <michael@databricks.com>

Closes #5511 from cloud-fan/6898 and squashes the following commits:

48e3e57 [Wenchen Fan] more style fix
820dc45 [Wenchen Fan] do not ignore newName in UnresolvedAttribute
d81ad43 [Wenchen Fan] fix style
11699d6 [Wenchen Fan] completely support special chars in column names
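For context, here is a minimal standalone Scala sketch of the naming scheme this patch moves to. The names NameParts and NamePartsDemo are hypothetical; the real class is catalyst's UnresolvedAttribute, which carries the full Attribute machinery. The point is that name parts are stored as a Seq[String], so a single part may itself contain dots.

    // Simplified model of the name handling this patch introduces.
    case class NameParts(parts: Seq[String]) {
      // Render a display name, re-quoting any part that itself contains a
      // dot, mirroring UnresolvedAttribute.name in this patch.
      def name: String =
        parts.map(p => if (p.contains(".")) s"`$p`" else p).mkString(".")
    }

    object NameParts {
      // Unquoted input: split on dots, as the old single-string form did.
      def apply(name: String): NameParts = new NameParts(name.split("\\.").toSeq)
      // Quoted (backticked) input: keep it as one part, dots and all.
      def quoted(name: String): NameParts = new NameParts(Seq(name))
    }

    object NamePartsDemo extends App {
      println(NameParts("a.b").parts)          // List(a, b) -- two parts
      println(NameParts.quoted("a#$b.c").name) // `a#$b.c`   -- one part, re-quoted
    }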
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala                |  4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala        | 13
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala   |  6
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala      | 14
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala | 27
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala   |  2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala                             |  4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala                         | 13
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala                           |  2
9 files changed, 52 insertions(+), 33 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
index 9a3531ceb3..0af969cc5c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -381,13 +381,13 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser {
| "(" ~> expression <~ ")"
| function
| dotExpressionHeader
- | ident ^^ UnresolvedAttribute
+ | ident ^^ {case i => UnresolvedAttribute.quoted(i)}
| signedPrimary
| "~" ~> expression ^^ BitwiseNot
)
protected lazy val dotExpressionHeader: Parser[Expression] =
(ident <~ ".") ~ ident ~ rep("." ~> ident) ^^ {
- case i1 ~ i2 ~ rest => UnresolvedAttribute((Seq(i1, i2) ++ rest).mkString("."))
+ case i1 ~ i2 ~ rest => UnresolvedAttribute(Seq(i1, i2) ++ rest)
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 8b68b0df35..cb49e5ad55 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -297,14 +297,15 @@ class Analyzer(
case q: LogicalPlan =>
logTrace(s"Attempting to resolve ${q.simpleString}")
q transformExpressionsUp {
- case u @ UnresolvedAttribute(name) if resolver(name, VirtualColumn.groupingIdName) &&
+ case u @ UnresolvedAttribute(nameParts) if nameParts.length == 1 &&
+ resolver(nameParts(0), VirtualColumn.groupingIdName) &&
q.isInstanceOf[GroupingAnalytics] =>
// Resolve the virtual column GROUPING__ID for the operator GroupingAnalytics
q.asInstanceOf[GroupingAnalytics].gid
- case u @ UnresolvedAttribute(name) =>
+ case u @ UnresolvedAttribute(nameParts) =>
// Leave unchanged if resolution fails. Hopefully will be resolved next round.
val result =
- withPosition(u) { q.resolveChildren(name, resolver).getOrElse(u) }
+ withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) }
logDebug(s"Resolving $u to $result")
result
case UnresolvedGetField(child, fieldName) if child.resolved =>
@@ -383,12 +384,12 @@ class Analyzer(
child: LogicalPlan,
grandchild: LogicalPlan): (Seq[SortOrder], Seq[Attribute]) = {
// Find any attributes that remain unresolved in the sort.
- val unresolved: Seq[String] =
- ordering.flatMap(_.collect { case UnresolvedAttribute(name) => name })
+ val unresolved: Seq[Seq[String]] =
+ ordering.flatMap(_.collect { case UnresolvedAttribute(nameParts) => nameParts })
// Create a map from name, to resolved attributes, when the desired name can be found
// prior to the projection.
- val resolved: Map[String, NamedExpression] =
+ val resolved: Map[Seq[String], NamedExpression] =
unresolved.flatMap(u => grandchild.resolve(u, resolver).map(a => u -> a)).toMap
// Construct a set that contains all of the attributes that we need to evaluate the
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index fa02111385..1155dac28f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -46,8 +46,12 @@ trait CheckAnalysis {
operator transformExpressionsUp {
case a: Attribute if !a.resolved =>
if (operator.childrenResolved) {
+ val nameParts = a match {
+ case UnresolvedAttribute(nameParts) => nameParts
+ case _ => Seq(a.name)
+ }
// Throw errors for specific problems with get field.
- operator.resolveChildren(a.name, resolver, throwErrors = true)
+ operator.resolveChildren(nameParts, resolver, throwErrors = true)
}
val from = operator.inputSet.map(_.name).mkString(", ")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index 300e9ba187..3f567e3e8b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -49,7 +49,12 @@ case class UnresolvedRelation(
/**
* Holds the name of an attribute that has yet to be resolved.
*/
-case class UnresolvedAttribute(name: String) extends Attribute with trees.LeafNode[Expression] {
+case class UnresolvedAttribute(nameParts: Seq[String])
+ extends Attribute with trees.LeafNode[Expression] {
+
+ def name: String =
+ nameParts.map(n => if (n.contains(".")) s"`$n`" else n).mkString(".")
+
override def exprId: ExprId = throw new UnresolvedException(this, "exprId")
override def dataType: DataType = throw new UnresolvedException(this, "dataType")
override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
@@ -59,7 +64,7 @@ case class UnresolvedAttribute(name: String) extends Attribute with trees.LeafNo
override def newInstance(): UnresolvedAttribute = this
override def withNullability(newNullability: Boolean): UnresolvedAttribute = this
override def withQualifiers(newQualifiers: Seq[String]): UnresolvedAttribute = this
- override def withName(newName: String): UnresolvedAttribute = UnresolvedAttribute(name)
+ override def withName(newName: String): UnresolvedAttribute = UnresolvedAttribute.quoted(newName)
// Unresolved attributes are transient at compile time and don't get evaluated during execution.
override def eval(input: Row = null): EvaluatedType =
@@ -68,6 +73,11 @@ case class UnresolvedAttribute(name: String) extends Attribute with trees.LeafNo
override def toString: String = s"'$name"
}
+object UnresolvedAttribute {
+ def apply(name: String): UnresolvedAttribute = new UnresolvedAttribute(name.split("\\."))
+ def quoted(name: String): UnresolvedAttribute = new UnresolvedAttribute(Seq(name))
+}
+
case class UnresolvedFunction(name: String, children: Seq[Expression]) extends Expression {
override def dataType: DataType = throw new UnresolvedException(this, "dataType")
override def foldable: Boolean = throw new UnresolvedException(this, "foldable")
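The Seq-based form above removes an ambiguity the old String form could not express: "a.b" may mean nested field b of column a, or a single column literally named "a.b". A standalone sketch (the demo object is hypothetical):

    object PartsAmbiguityDemo extends App {
      val fieldAccess  = Seq("a", "b") // column a, then nested field b
      val dottedColumn = Seq("a.b")    // one column whose name contains a dot

      assert(fieldAccess != dottedColumn) // distinct, though both render "a.b"

      // Rendering mirrors UnresolvedAttribute.name in this patch: dotted
      // parts are wrapped in backticks so the display name round-trips.
      def render(parts: Seq[String]): String =
        parts.map(p => if (p.contains(".")) s"`$p`" else p).mkString(".")

      println(render(fieldAccess))  // a.b
      println(render(dottedColumn)) // `a.b`
    }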
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 579a0fb8d3..ae4620a4e5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -19,12 +19,11 @@ package org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.Logging
import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, UnresolvedGetField, Resolver}
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, EliminateSubQueries, Resolver}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.catalyst.trees
-import org.apache.spark.sql.types.{ArrayType, StructType, StructField}
abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
@@ -111,10 +110,10 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
* as string in the following form: `[scope].AttributeName.[nested].[fields]...`.
*/
def resolveChildren(
- name: String,
+ nameParts: Seq[String],
resolver: Resolver,
throwErrors: Boolean = false): Option[NamedExpression] =
- resolve(name, children.flatMap(_.output), resolver, throwErrors)
+ resolve(nameParts, children.flatMap(_.output), resolver, throwErrors)
/**
* Optionally resolves the given string to a [[NamedExpression]] based on the output of this
@@ -122,10 +121,10 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
* `[scope].AttributeName.[nested].[fields]...`.
*/
def resolve(
- name: String,
+ nameParts: Seq[String],
resolver: Resolver,
throwErrors: Boolean = false): Option[NamedExpression] =
- resolve(name, output, resolver, throwErrors)
+ resolve(nameParts, output, resolver, throwErrors)
/**
* Resolve the given `name` string against the given attribute, returning either 0 or 1 match.
@@ -135,7 +134,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
* See the comment above `candidates` variable in resolve() for semantics the returned data.
*/
private def resolveAsTableColumn(
- nameParts: Array[String],
+ nameParts: Seq[String],
resolver: Resolver,
attribute: Attribute): Option[(Attribute, List[String])] = {
assert(nameParts.length > 1)
@@ -155,7 +154,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
* See the comment above `candidates` variable in resolve() for semantics the returned data.
*/
private def resolveAsColumn(
- nameParts: Array[String],
+ nameParts: Seq[String],
resolver: Resolver,
attribute: Attribute): Option[(Attribute, List[String])] = {
if (resolver(attribute.name, nameParts.head)) {
@@ -167,13 +166,11 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
/** Performs attribute resolution given a name and a sequence of possible attributes. */
protected def resolve(
- name: String,
+ nameParts: Seq[String],
input: Seq[Attribute],
resolver: Resolver,
throwErrors: Boolean): Option[NamedExpression] = {
- val parts = name.split("\\.")
-
// A sequence of possible candidate matches.
// Each candidate is a tuple. The first element is a resolved attribute, followed by a list
// of parts that are to be resolved.
@@ -182,9 +179,9 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
// and the second element will be List("c").
var candidates: Seq[(Attribute, List[String])] = {
// If the name has 2 or more parts, try to resolve it as `table.column` first.
- if (parts.length > 1) {
+ if (nameParts.length > 1) {
input.flatMap { option =>
- resolveAsTableColumn(parts, resolver, option)
+ resolveAsTableColumn(nameParts, resolver, option)
}
} else {
Seq.empty
@@ -194,10 +191,12 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
// If none of attributes match `table.column` pattern, we try to resolve it as a column.
if (candidates.isEmpty) {
candidates = input.flatMap { candidate =>
- resolveAsColumn(parts, resolver, candidate)
+ resolveAsColumn(nameParts, resolver, candidate)
}
}
+ def name = UnresolvedAttribute(nameParts).name
+
candidates.distinct match {
// One match, no nested fields, use it.
case Seq((a, Nil)) => Some(a)
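A simplified toy model of the candidate search that resolve performs over pre-split name parts (Attr, Resolver, and the sample data are hypothetical; nested-field extraction and ambiguity errors are omitted): with two or more parts it first tries `table.column`, then falls back to matching the head part as a bare column name.

    object ResolveSketch extends App {
      type Resolver = (String, String) => Boolean
      case class Attr(qualifier: Option[String], name: String)

      def resolve(
          parts: Seq[String],
          input: Seq[Attr],
          resolver: Resolver): Seq[(Attr, List[String])] = {
        // With 2+ parts, first try `table.column`: the qualifier matches
        // parts(0) and the column name matches parts(1); anything left
        // over is a list of nested fields still to be resolved.
        val asTableColumn =
          if (parts.length > 1) {
            input.collect {
              case a if a.qualifier.exists(resolver(_, parts.head)) &&
                        resolver(a.name, parts(1)) =>
                (a, parts.drop(2).toList)
            }
          } else Seq.empty

        // If nothing matched `table.column`, try the head as a bare column.
        if (asTableColumn.nonEmpty) asTableColumn
        else input.collect {
          case a if resolver(a.name, parts.head) => (a, parts.tail.toList)
        }
      }

      val caseInsensitive: Resolver = _.equalsIgnoreCase(_)
      val attrs = Seq(Attr(Some("t"), "a.b"), Attr(None, "c"))

      // Seq("t", "a.b") resolves as table t, column `a.b` -- no dot-splitting.
      println(resolve(Seq("t", "a.b"), attrs, caseInsensitive))
      println(resolve(Seq("c"), attrs, caseInsensitive))
    }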
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 6e3d6b9263..e10ddfdf51 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -27,8 +27,6 @@ import org.apache.spark.sql.types._
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
-import scala.collection.immutable
-
class AnalysisSuite extends FunSuite with BeforeAndAfter {
val caseSensitiveCatalog = new SimpleCatalog(true)
val caseInsensitiveCatalog = new SimpleCatalog(false)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 94ae2d65fd..3235f85d5b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -158,7 +158,7 @@ class DataFrame private[sql](
}
protected[sql] def resolve(colName: String): NamedExpression = {
- queryExecution.analyzed.resolve(colName, sqlContext.analyzer.resolver).getOrElse {
+ queryExecution.analyzed.resolve(colName.split("\\."), sqlContext.analyzer.resolver).getOrElse {
throw new AnalysisException(
s"""Cannot resolve column name "$colName" among (${schema.fieldNames.mkString(", ")})""")
}
@@ -166,7 +166,7 @@ class DataFrame private[sql](
protected[sql] def numericColumns: Seq[Expression] = {
schema.fields.filter(_.dataType.isInstanceOf[NumericType]).map { n =>
- queryExecution.analyzed.resolve(n.name, sqlContext.analyzer.resolver).get
+ queryExecution.analyzed.resolve(n.name.split("\\."), sqlContext.analyzer.resolver).get
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 4c48dca444..d739e550f3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -19,14 +19,13 @@ package org.apache.spark.sql
import org.scalatest.BeforeAndAfterAll
-import org.apache.spark.sql.TestData._
import org.apache.spark.sql.execution.GeneratedAggregate
import org.apache.spark.sql.functions._
+import org.apache.spark.sql.TestData._
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext.{udf => _, _}
import org.apache.spark.sql.types._
-
class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
// Make sure the tables are loaded.
TestData
@@ -1125,7 +1124,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
val data = sparkContext.parallelize(
Seq("""{"key?number1": "value1", "key.number2": "value2"}"""))
jsonRDD(data).registerTempTable("records")
- sql("SELECT `key?number1` FROM records")
+ sql("SELECT `key?number1`, `key.number2` FROM records")
}
test("SPARK-3814 Support Bitwise & operator") {
@@ -1225,4 +1224,12 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
checkAnswer(sql("SELECT a.b[0] FROM t ORDER BY c0.a"), Row(1))
checkAnswer(sql("SELECT b[0].a FROM t ORDER BY c0.a"), Row(1))
}
+
+ test("SPARK-6898: complete support for special chars in column names") {
+ jsonRDD(sparkContext.makeRDD(
+ """{"a": {"c.b": 1}, "b.$q": [{"a@!.q": 1}], "q.w": {"w.i&": [1]}}""" :: Nil))
+ .registerTempTable("t")
+
+ checkAnswer(sql("SELECT a.`c.b`, `b.$q`[0].`a@!.q`, `q.w`.`w.i&`[0] FROM t"), Row(1, 1, 1))
+ }
}
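For reference, a usage sketch along the lines of the new test, assuming the 1.x-era test context above (jsonRDD, sparkContext, registerTempTable, and sql in scope): each backticked segment is exactly one name part, so dots inside backticks belong to the column or field name.

    val json = """{"a": {"c.b": 1}, "b.$q": [{"a@!.q": 1}], "q.w": {"w.i&": [1]}}"""
    jsonRDD(sparkContext.makeRDD(json :: Nil)).registerTempTable("t")

    // a.`c.b`           -> struct column a, nested field `c.b`
    // `b.$q`[0].`a@!.q` -> array column `b.$q`, element 0, field `a@!.q`
    // `q.w`.`w.i&`[0]   -> struct column `q.w`, array field `w.i&`, element 0
    sql("SELECT a.`c.b`, `b.$q`[0].`a@!.q`, `q.w`.`w.i&`[0] FROM t").show()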
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index 53a204b8c2..fd305eb480 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -1101,7 +1101,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
case Token(".", qualifier :: Token(attr, Nil) :: Nil) =>
nodeToExpr(qualifier) match {
case UnresolvedAttribute(qualifierName) =>
- UnresolvedAttribute(qualifierName + "." + cleanIdentifier(attr))
+ UnresolvedAttribute(qualifierName :+ cleanIdentifier(attr))
case other => UnresolvedGetField(other, attr)
}