about | summary | refs | log | tree | commit | diff
path: root/sql
diff options
context:
space:
mode:
author Reynold Xin <rxin@databricks.com> 2015-01-29 19:09:08 -0800
committer Reynold Xin <rxin@databricks.com> 2015-01-29 19:09:08 -0800
commit 80def9deb3bfc30d5b622b32aecb0322341a7f62 (patch)
tree 9b42a7be87468d451aa3f7e3c3d06438b5b007e5 /sql
parent 22271f969363fd139e6cfb5a2d95a2607fb4e572 (diff)
download spark-80def9deb3bfc30d5b622b32aecb0322341a7f62.tar.gz
spark-80def9deb3bfc30d5b622b32aecb0322341a7f62.tar.bz2
spark-80def9deb3bfc30d5b622b32aecb0322341a7f62.zip
[SQL] Support df("*") to select all columns in a data frame.
This PR makes Star a trait, and provides two implementations: UnresolvedStar (used for *, tblName.*) and ResolvedStar (used for df("*")).

Author: Reynold Xin <rxin@databricks.com>

Closes #4283 from rxin/df-star and squashes the following commits:

c9cba3e [Reynold Xin] Removed mapFunction in UnresolvedStar.
1a3a1d7 [Reynold Xin] [SQL] Support df("*") to select all columns in a data frame.
Diffstat (limited to 'sql')
-rwxr-xr-x sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala 2
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala 53
-rw-r--r-- sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala 4
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/Column.scala 6
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala 4
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala 8
-rw-r--r-- sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala 6
7 files changed, 54 insertions, 29 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
index eaadbe9fd5..24a65f8f4d 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -348,7 +348,7 @@ class SqlParser extends AbstractSparkSQLParser {
)
protected lazy val baseExpression: Parser[Expression] =
- ( "*" ^^^ Star(None)
+ ( "*" ^^^ UnresolvedStar(None)
| primary
)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index 71a738a0b2..6606028918 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -50,7 +50,7 @@ case class UnresolvedAttribute(name: String) extends Attribute with trees.LeafNo
override def qualifiers = throw new UnresolvedException(this, "qualifiers")
override lazy val resolved = false
- override def newInstance = this
+ override def newInstance() = this
override def withNullability(newNullability: Boolean) = this
override def withQualifiers(newQualifiers: Seq[String]) = this
override def withName(newName: String) = UnresolvedAttribute(name)
@@ -77,15 +77,10 @@ case class UnresolvedFunction(name: String, children: Seq[Expression]) extends E
/**
* Represents all of the input attributes to a given relational operator, for example in
- * "SELECT * FROM ...".
- *
- * @param table an optional table that should be the target of the expansion. If omitted all
- * tables' columns are produced.
+ * "SELECT * FROM ...". A [[Star]] gets automatically expanded during analysis.
*/
-case class Star(
- table: Option[String],
- mapFunction: Attribute => Expression = identity[Attribute])
- extends Attribute with trees.LeafNode[Expression] {
+trait Star extends Attribute with trees.LeafNode[Expression] {
+ self: Product =>
override def name = throw new UnresolvedException(this, "name")
override def exprId = throw new UnresolvedException(this, "exprId")
@@ -94,29 +89,53 @@ case class Star(
override def qualifiers = throw new UnresolvedException(this, "qualifiers")
override lazy val resolved = false
- override def newInstance = this
+ override def newInstance() = this
override def withNullability(newNullability: Boolean) = this
override def withQualifiers(newQualifiers: Seq[String]) = this
override def withName(newName: String) = this
- def expand(input: Seq[Attribute], resolver: Resolver): Seq[NamedExpression] = {
+ // Star gets expanded at runtime so we never evaluate a Star.
+ override def eval(input: Row = null): EvaluatedType =
+ throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
+
+ def expand(input: Seq[Attribute], resolver: Resolver): Seq[NamedExpression]
+}
+
+
+/**
+ * Represents all of the input attributes to a given relational operator, for example in
+ * "SELECT * FROM ...".
+ *
+ * @param table an optional table that should be the target of the expansion. If omitted all
+ * tables' columns are produced.
+ */
+case class UnresolvedStar(table: Option[String]) extends Star {
+
+ override def expand(input: Seq[Attribute], resolver: Resolver): Seq[NamedExpression] = {
val expandedAttributes: Seq[Attribute] = table match {
// If there is no table specified, use all input attributes.
case None => input
// If there is a table, pick out attributes that are part of this table.
case Some(t) => input.filter(_.qualifiers.filter(resolver(_, t)).nonEmpty)
}
- val mappedAttributes = expandedAttributes.map(mapFunction).zip(input).map {
+ expandedAttributes.zip(input).map {
case (n: NamedExpression, _) => n
case (e, originalAttribute) =>
Alias(e, originalAttribute.name)(qualifiers = originalAttribute.qualifiers)
}
- mappedAttributes
}
- // Star gets expanded at runtime so we never evaluate a Star.
- override def eval(input: Row = null): EvaluatedType =
- throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
-
override def toString = table.map(_ + ".").getOrElse("") + "*"
}
+
+
+/**
+ * Represents all the resolved input attributes to a given relational operator. This is used
+ * in the data frame DSL.
+ *
+ * @param expressions Expressions to expand.
+ */
+case class ResolvedStar(expressions: Seq[NamedExpression]) extends Star {
+ override def expand(input: Seq[Attribute], resolver: Resolver): Seq[NamedExpression] = expressions
+ override def toString = expressions.mkString("ResolvedStar(", ", ", ")")
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 3aea337460..60060bf029 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -51,7 +51,9 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
test("union project *") {
val plan = (1 to 100)
.map(_ => testRelation)
- .fold[LogicalPlan](testRelation)((a,b) => a.select(Star(None)).select('a).unionAll(b.select(Star(None))))
+ .fold[LogicalPlan](testRelation) { (a, b) =>
+ a.select(UnresolvedStar(None)).select('a).unionAll(b.select(UnresolvedStar(None)))
+ }
assert(caseInsensitiveAnalyze(plan).resolved)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 68c9cb0c02..174c403059 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql
import scala.language.implicitConversions
import org.apache.spark.sql.Dsl.lit
-import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, Star}
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedStar, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{Project, LogicalPlan}
import org.apache.spark.sql.types._
@@ -71,8 +71,8 @@ class Column(
* - "df.*" becomes an expression selecting all columns in data frame "df".
*/
def this(name: String) = this(name match {
- case "*" => Star(None)
- case _ if name.endsWith(".*") => Star(Some(name.substring(0, name.length - 2)))
+ case "*" => UnresolvedStar(None)
+ case _ if name.endsWith(".*") => UnresolvedStar(Some(name.substring(0, name.length - 2)))
case _ => UnresolvedAttribute(name)
})
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 2694e81eac..1096e39659 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -31,7 +31,7 @@ import org.apache.spark.api.python.SerDeUtil
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
import org.apache.spark.sql.catalyst.ScalaReflection
-import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
+import org.apache.spark.sql.catalyst.analysis.{ResolvedStar, UnresolvedRelation}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.{JoinType, Inner}
import org.apache.spark.sql.catalyst.plans.logical._
@@ -265,7 +265,7 @@ class DataFrame protected[sql](
*/
override def apply(colName: String): Column = colName match {
case "*" =>
- Column("*")
+ new Column(ResolvedStar(schema.fieldNames.map(resolve)))
case _ =>
val expr = resolve(colName)
new Column(Some(sqlContext), Some(Project(Seq(expr), logicalPlan)), expr)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index 6428554ec7..2d464c2b53 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -31,10 +31,14 @@ class ColumnExpressionSuite extends QueryTest {
checkAnswer(testData.select($"*"), testData.collect().toSeq)
}
- ignore("star qualified by data frame object") {
+ test("star qualified by data frame object") {
// df("*") expands to the columns of df itself, so the litCol added in df1 is not selected.
val df = testData.toDataFrame
- checkAnswer(df.select(df("*")), df.collect().toSeq)
+ val goldAnswer = df.collect().toSeq
+ checkAnswer(df.select(df("*")), goldAnswer)
+
+ val df1 = df.select(df("*"), lit("abcd").as("litCol"))
+ checkAnswer(df1.select(df("*")), goldAnswer)
}
test("star qualified by table name") {
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index 5e29e57d93..399e58b259 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -1002,11 +1002,11 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
}
/* Stars (*) */
- case Token("TOK_ALLCOLREF", Nil) => Star(None)
+ case Token("TOK_ALLCOLREF", Nil) => UnresolvedStar(None)
// The format of dbName.tableName.* cannot be parsed by HiveParser. TOK_TABNAME will only
// have a single child, which is tableName.
case Token("TOK_ALLCOLREF", Token("TOK_TABNAME", Token(name, Nil) :: Nil) :: Nil) =>
- Star(Some(name))
+ UnresolvedStar(Some(name))
/* Aggregate Functions */
case Token("TOK_FUNCTION", Token(AVG(), Nil) :: arg :: Nil) => Average(nodeToExpr(arg))
@@ -1145,7 +1145,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
case Token("TOK_FUNCTION", Token(name, Nil) :: args) =>
UnresolvedFunction(name, args.map(nodeToExpr))
case Token("TOK_FUNCTIONSTAR", Token(name, Nil) :: args) =>
- UnresolvedFunction(name, Star(None) :: Nil)
+ UnresolvedFunction(name, UnresolvedStar(None) :: Nil)
/* Literals */
case Token("TOK_NULL", Nil) => Literal(null, NullType)