diff options
-rwxr-xr-x | sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala | 9 | ||||
-rw-r--r-- | sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala | 12 |
2 files changed, 21 insertions, 0 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 8e39f79d2c..9608e15c0f 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -135,6 +135,15 @@ package object dsl { implicit def symbolToUnresolvedAttribute(s: Symbol): analysis.UnresolvedAttribute = analysis.UnresolvedAttribute(s.name) + /** Converts $"col name" into an [[analysis.UnresolvedAttribute]]. */ + implicit class StringToAttributeConversionHelper(val sc: StringContext) { + // Note that if we make ExpressionConversions an object rather than a trait, we can + // then make this a value class to avoid the small penalty of runtime instantiation. + def $(args: Any*): analysis.UnresolvedAttribute = { + analysis.UnresolvedAttribute(sc.s(args :_*)) + } + } + def sum(e: Expression) = Sum(e) def sumDistinct(e: Expression) = SumDistinct(e) def count(e: Expression) = Count(e) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala index c0b9cf5163..ab88f3ad10 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala @@ -56,6 +56,18 @@ class DslQuerySuite extends QueryTest { ) } + test("convert $\"attribute name\" into unresolved attribute") { + checkAnswer( + testData.where($"key" === 1).select($"value"), + Seq(Seq("1"))) + } + + test("convert Scala Symbol 'attrname into unresolved attribute") { + checkAnswer( + testData.where('key === 1).select('value), + Seq(Seq("1"))) + } + test("select *") { checkAnswer( testData.select(Star(None)), |