aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMichael Armbrust <michael@databricks.com>2015-08-06 17:31:16 -0700
committerReynold Xin <rxin@databricks.com>2015-08-06 17:31:16 -0700
commit0867b23c74a3e6347d718b67ddabff17b468eded (patch)
tree4b2bc2efd0a7fa8e5eb22259a671570bdafc0a84
parent4e70e8256ce2f45b438642372329eac7b1e9e8cf (diff)
downloadspark-0867b23c74a3e6347d718b67ddabff17b468eded.tar.gz
spark-0867b23c74a3e6347d718b67ddabff17b468eded.tar.bz2
spark-0867b23c74a3e6347d718b67ddabff17b468eded.zip
[SPARK-9650][SQL] Fix quoting behavior on interpolated column names
Make sure that `$"column"` is consistent with other methods with respect to backticks. Adds a bunch of tests for various ways of constructing columns. Author: Michael Armbrust <michael@databricks.com> Closes #7969 from marmbrus/namesWithDots and squashes the following commits: 53ef3d7 [Michael Armbrust] [SPARK-9650][SQL] Fix quoting behavior on interpolated column names 2bf7a92 [Michael Armbrust] WIP
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala57
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala42
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/Column.scala2
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala2
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala68
5 files changed, 128 insertions, 43 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index 03da45b09f..43ee319193 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.catalyst.analysis
+import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.errors
import org.apache.spark.sql.catalyst.expressions._
@@ -69,8 +70,64 @@ case class UnresolvedAttribute(nameParts: Seq[String]) extends Attribute with Un
}
object UnresolvedAttribute {
+ /**
+ * Creates an [[UnresolvedAttribute]], parsing segments separated by dots ('.').
+ */
def apply(name: String): UnresolvedAttribute = new UnresolvedAttribute(name.split("\\."))
+
+ /**
+ * Creates an [[UnresolvedAttribute]], from a single quoted string (for example using backticks in
+ * HiveQL). Since the string is considered quoted, no processing is done on the name.
+ */
def quoted(name: String): UnresolvedAttribute = new UnresolvedAttribute(Seq(name))
+
+ /**
+ * Creates an [[UnresolvedAttribute]] from a string in an embedded language. In this case
+ * we treat it as a quoted identifier, except for '.', which must be further quoted using
+ * backticks if it is part of a column name.
+ */
+ def quotedString(name: String): UnresolvedAttribute =
+ new UnresolvedAttribute(parseAttributeName(name))
+
+ /**
+ * Used to split attribute name by dot with backticks rule.
+ * Backticks must appear in pairs, and the quoted string must be a complete name part,
+ * which means `ab..c`e.f is not allowed.
+ * Escape character is not supported now, so we can't use backtick inside name part.
+ */
+ def parseAttributeName(name: String): Seq[String] = {
+ def e = new AnalysisException(s"syntax error in attribute name: $name")
+ val nameParts = scala.collection.mutable.ArrayBuffer.empty[String]
+ val tmp = scala.collection.mutable.ArrayBuffer.empty[Char]
+ var inBacktick = false
+ var i = 0
+ while (i < name.length) {
+ val char = name(i)
+ if (inBacktick) {
+ if (char == '`') {
+ inBacktick = false
+ if (i + 1 < name.length && name(i + 1) != '.') throw e
+ } else {
+ tmp += char
+ }
+ } else {
+ if (char == '`') {
+ if (tmp.nonEmpty) throw e
+ inBacktick = true
+ } else if (char == '.') {
+ if (name(i - 1) == '.' || i == name.length - 1) throw e
+ nameParts += tmp.mkString
+ tmp.clear()
+ } else {
+ tmp += char
+ }
+ }
+ i += 1
+ }
+ if (inBacktick) throw e
+ nameParts += tmp.mkString
+ nameParts.toSeq
+ }
}
case class UnresolvedFunction(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 9b52f02009..c290e6acb3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -179,47 +179,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
def resolveQuoted(
name: String,
resolver: Resolver): Option[NamedExpression] = {
- resolve(parseAttributeName(name), output, resolver)
- }
-
- /**
- * Internal method, used to split attribute name by dot with backticks rule.
- * Backticks must appear in pairs, and the quoted string must be a complete name part,
- * which means `ab..c`e.f is not allowed.
- * Escape character is not supported now, so we can't use backtick inside name part.
- */
- private def parseAttributeName(name: String): Seq[String] = {
- val e = new AnalysisException(s"syntax error in attribute name: $name")
- val nameParts = scala.collection.mutable.ArrayBuffer.empty[String]
- val tmp = scala.collection.mutable.ArrayBuffer.empty[Char]
- var inBacktick = false
- var i = 0
- while (i < name.length) {
- val char = name(i)
- if (inBacktick) {
- if (char == '`') {
- inBacktick = false
- if (i + 1 < name.length && name(i + 1) != '.') throw e
- } else {
- tmp += char
- }
- } else {
- if (char == '`') {
- if (tmp.nonEmpty) throw e
- inBacktick = true
- } else if (char == '.') {
- if (name(i - 1) == '.' || i == name.length - 1) throw e
- nameParts += tmp.mkString
- tmp.clear()
- } else {
- tmp += char
- }
- }
- i += 1
- }
- if (inBacktick) throw e
- nameParts += tmp.mkString
- nameParts.toSeq
+ resolve(UnresolvedAttribute.parseAttributeName(name), output, resolver)
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 75365fbcec..27bd084847 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -54,7 +54,7 @@ class Column(protected[sql] val expr: Expression) extends Logging {
def this(name: String) = this(name match {
case "*" => UnresolvedStar(None)
case _ if name.endsWith(".*") => UnresolvedStar(Some(name.substring(0, name.length - 2)))
- case _ => UnresolvedAttribute(name)
+ case _ => UnresolvedAttribute.quotedString(name)
})
/** Creates a column based on the given expression. */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 6f8ffb5440..075c0ea254 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -343,7 +343,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
*/
implicit class StringToColumn(val sc: StringContext) {
def $(args: Any*): ColumnName = {
- new ColumnName(sc.s(args : _*))
+ new ColumnName(sc.s(args: _*))
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index e1b3443d74..6a09a3b72c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -32,6 +32,74 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils {
override def sqlContext(): SQLContext = ctx
+ test("column names with space") {
+ val df = Seq((1, "a")).toDF("name with space", "name.with.dot")
+
+ checkAnswer(
+ df.select(df("name with space")),
+ Row(1) :: Nil)
+
+ checkAnswer(
+ df.select($"name with space"),
+ Row(1) :: Nil)
+
+ checkAnswer(
+ df.select(col("name with space")),
+ Row(1) :: Nil)
+
+ checkAnswer(
+ df.select("name with space"),
+ Row(1) :: Nil)
+
+ checkAnswer(
+ df.select(expr("`name with space`")),
+ Row(1) :: Nil)
+ }
+
+ test("column names with dot") {
+ val df = Seq((1, "a")).toDF("name with space", "name.with.dot").as("a")
+
+ checkAnswer(
+ df.select(df("`name.with.dot`")),
+ Row("a") :: Nil)
+
+ checkAnswer(
+ df.select($"`name.with.dot`"),
+ Row("a") :: Nil)
+
+ checkAnswer(
+ df.select(col("`name.with.dot`")),
+ Row("a") :: Nil)
+
+ checkAnswer(
+ df.select("`name.with.dot`"),
+ Row("a") :: Nil)
+
+ checkAnswer(
+ df.select(expr("`name.with.dot`")),
+ Row("a") :: Nil)
+
+ checkAnswer(
+ df.select(df("a.`name.with.dot`")),
+ Row("a") :: Nil)
+
+ checkAnswer(
+ df.select($"a.`name.with.dot`"),
+ Row("a") :: Nil)
+
+ checkAnswer(
+ df.select(col("a.`name.with.dot`")),
+ Row("a") :: Nil)
+
+ checkAnswer(
+ df.select("a.`name.with.dot`"),
+ Row("a") :: Nil)
+
+ checkAnswer(
+ df.select(expr("a.`name.with.dot`")),
+ Row("a") :: Nil)
+ }
+
test("alias") {
val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
assert(df.select(df("a").as("b")).columns.head === "b")