diff options
author | Joseph Batchik <josephbatchik@gmail.com> | 2015-07-28 14:39:25 -0700 |
---|---|---|
committer | Reynold Xin <rxin@databricks.com> | 2015-07-28 14:39:25 -0700 |
commit | b88b868eb378bdb7459978842b5572a0b498f412 (patch) | |
tree | 4396c86c3a5fe95e1cb8b0325cd3f7e95e01aee1 | |
parent | 8d5bb5283c3cc9180ef34b05be4a715d83073b1e (diff) | |
download | spark-b88b868eb378bdb7459978842b5572a0b498f412.tar.gz spark-b88b868eb378bdb7459978842b5572a0b498f412.tar.bz2 spark-b88b868eb378bdb7459978842b5572a0b498f412.zip |
[SPARK-8003][SQL] Added virtual column support to Spark
Added virtual column support by adding a new resolution role to the query analyzer. Additional virtual columns can be added by adding case expressions to [the new rule](https://github.com/JDrit/spark/blob/virt_columns/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala#L1026) and my modifying the [logical plan](https://github.com/JDrit/spark/blob/virt_columns/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala#L216) to resolve them.
This also solves [SPARK-8003](https://issues.apache.org/jira/browse/SPARK-8003)
This allows you to perform queries such as:
```sql
select spark__partition__id, count(*) as c from table group by spark__partition__id;
```
Author: Joseph Batchik <josephbatchik@gmail.com>
Author: JD <jd@csh.rit.edu>
Closes #7478 from JDrit/virt_columns and squashes the following commits:
7932bf0 [Joseph Batchik] adding spark__partition__id to hive as well
f8a9c6c [Joseph Batchik] merging in master
e49da48 [JD] fixes for @rxin's suggestions
60e120b [JD] fixing test in merge
4bf8554 [JD] merging in master
c68bc0f [Joseph Batchik] Adding function register ability to SQLContext and adding a function for spark__partition__id()
8 files changed, 40 insertions, 8 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 61ee6f6f71..9b60943a1e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -239,7 +239,7 @@ object FunctionRegistry { } /** See usage above. */ - private def expression[T <: Expression](name: String) + def expression[T <: Expression](name: String) (implicit tag: ClassTag[T]): (String, (ExpressionInfo, FunctionBuilder)) = { // See if we can find a constructor that accepts Seq[Expression] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index dbb2a09846..56cd8f22e7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -31,6 +31,8 @@ import org.apache.spark.SparkContext import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.{expression => FunctionExpression, FunctionBuilder} +import org.apache.spark.sql.execution.expressions.SparkPartitionID import org.apache.spark.sql.SQLConf.SQLConfEntry import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.errors.DialectException @@ -140,7 +142,14 @@ class SQLContext(@transient val sparkContext: SparkContext) // TODO how to handle the temp function per user session? @transient - protected[sql] lazy val functionRegistry: FunctionRegistry = FunctionRegistry.builtin + protected[sql] lazy val functionRegistry: FunctionRegistry = { + val reg = FunctionRegistry.builtin + val extendedFunctions = List[(String, (ExpressionInfo, FunctionBuilder))]( + FunctionExpression[SparkPartitionID]("spark__partition__id") + ) + extendedFunctions.foreach { case(name, (info, fun)) => reg.registerFunction(name, info, fun) } + reg + } @transient protected[sql] lazy val analyzer: Analyzer = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala index 61ef079d89..98c8eab837 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.types.{IntegerType, DataType} /** * Expression that returns the current partition id of the Spark task. */ -private[sql] case object SparkPartitionID extends LeafExpression with Nondeterministic { +private[sql] case class SparkPartitionID() extends LeafExpression with Nondeterministic { override def nullable: Boolean = false diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index cec61b66b1..0148991512 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -741,7 +741,7 @@ object functions { * @group normal_funcs * @since 1.4.0 */ - def sparkPartitionId(): Column = execution.expressions.SparkPartitionID + def sparkPartitionId(): Column = execution.expressions.SparkPartitionID() /** * Computes the square root of the specified float value. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index c1516b450c..9b326c1635 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -51,6 +51,13 @@ class UDFSuite extends QueryTest { df.selectExpr("count(distinct a)") } + test("SPARK-8003 spark__partition__id") { + val df = Seq((1, "Tearing down the walls that divide us")).toDF("id", "saying") + df.registerTempTable("tmp_table") + checkAnswer(ctx.sql("select spark__partition__id() from tmp_table").toDF(), Row(0)) + ctx.dropTempTable("tmp_table") + } + test("error reporting for incorrect number of arguments") { val df = ctx.emptyDataFrame val e = intercept[AnalysisException] { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/expression/NondeterministicSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/expression/NondeterministicSuite.scala index 1c5a2ed2c0..b6e79ff9cc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/expression/NondeterministicSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/expression/NondeterministicSuite.scala @@ -27,6 +27,6 @@ class NondeterministicSuite extends SparkFunSuite with ExpressionEvalHelper { } test("SparkPartitionID") { - checkEvaluation(SparkPartitionID, 0) + checkEvaluation(SparkPartitionID(), 0) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 110f51a305..8b35c1275f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -38,6 +38,9 @@ import org.apache.spark.Logging import org.apache.spark.SparkContext import org.apache.spark.annotation.Experimental import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.{expression => FunctionExpression, FunctionBuilder} +import org.apache.spark.sql.catalyst.expressions.ExpressionInfo +import org.apache.spark.sql.execution.expressions.SparkPartitionID import org.apache.spark.sql.SQLConf.SQLConfEntry import org.apache.spark.sql.SQLConf.SQLConfEntry._ import org.apache.spark.sql.catalyst.{TableIdentifier, ParserDialect} @@ -372,8 +375,14 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { // Note that HiveUDFs will be overridden by functions registered in this context. @transient - override protected[sql] lazy val functionRegistry: FunctionRegistry = - new HiveFunctionRegistry(FunctionRegistry.builtin) + override protected[sql] lazy val functionRegistry: FunctionRegistry = { + val reg = new HiveFunctionRegistry(FunctionRegistry.builtin) + val extendedFunctions = List[(String, (ExpressionInfo, FunctionBuilder))]( + FunctionExpression[SparkPartitionID]("spark__partition__id") + ) + extendedFunctions.foreach { case(name, (info, fun)) => reg.registerFunction(name, info, fun) } + reg + } /* An analyzer that uses the Hive metastore. */ @transient diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala index 4056dee777..9cea5d413c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala @@ -17,13 +17,14 @@ package org.apache.spark.sql.hive -import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.{Row, QueryTest} case class FunctionResult(f1: String, f2: String) class UDFSuite extends QueryTest { private lazy val ctx = org.apache.spark.sql.hive.test.TestHive + import ctx.implicits._ test("UDF case insensitive") { ctx.udf.register("random0", () => { Math.random() }) @@ -33,4 +34,10 @@ class UDFSuite extends QueryTest { assert(ctx.sql("SELECT RANDOm1() FROM src LIMIT 1").head().getDouble(0) >= 0.0) assert(ctx.sql("SELECT strlenscala('test', 1) FROM src LIMIT 1").head().getInt(0) === 5) } + + test("SPARK-8003 spark__partition__id") { + val df = Seq((1, "Two Fiiiiive")).toDF("id", "saying") + ctx.registerDataFrameAsTable(df, "test_table") + checkAnswer(ctx.sql("select spark__partition__id() from test_table LIMIT 1").toDF(), Row(0)) + } } |