aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-xsql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala9
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala12
2 files changed, 21 insertions, 0 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 8e39f79d2c..9608e15c0f 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -135,6 +135,15 @@ package object dsl {
implicit def symbolToUnresolvedAttribute(s: Symbol): analysis.UnresolvedAttribute =
analysis.UnresolvedAttribute(s.name)
+ /** Converts $"col name" into an [[analysis.UnresolvedAttribute]]. */
+ implicit class StringToAttributeConversionHelper(val sc: StringContext) {
+ // Note that if we make ExpressionConversions an object rather than a trait, we can
+ // then make this a value class to avoid the small penalty of runtime instantiation.
+ def $(args: Any*): analysis.UnresolvedAttribute = {
+ analysis.UnresolvedAttribute(sc.s(args :_*))
+ }
+ }
+
def sum(e: Expression) = Sum(e)
def sumDistinct(e: Expression) = SumDistinct(e)
def count(e: Expression) = Count(e)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
index c0b9cf5163..ab88f3ad10 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
@@ -56,6 +56,18 @@ class DslQuerySuite extends QueryTest {
)
}
+ test("convert $\"attribute name\" into unresolved attribute") {
+ checkAnswer(
+ testData.where($"key" === 1).select($"value"),
+ Seq(Seq("1")))
+ }
+
+ test("convert Scala Symbol 'attrname into unresolved attribute") {
+ checkAnswer(
+ testData.where('key === 1).select('value),
+ Seq(Seq("1")))
+ }
+
test("select *") {
checkAnswer(
testData.select(Star(None)),