 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala     |  2 +-
 sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala                                 | 11 ++++++++++-
 sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala     |  2 +-
 sql/core/src/main/scala/org/apache/spark/sql/functions.scala                                  |  2 +-
 sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala                                   |  7 +++++++
 sql/core/src/test/scala/org/apache/spark/sql/execution/expression/NondeterministicSuite.scala |  2 +-
 sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala                           | 13 +++++++++++--
 sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala                              |  9 ++++++++-
 8 files changed, 40 insertions(+), 8 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 61ee6f6f71..9b60943a1e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -239,7 +239,7 @@ object FunctionRegistry {
}
/** See usage above. */
- private def expression[T <: Expression](name: String)
+ def expression[T <: Expression](name: String)
(implicit tag: ClassTag[T]): (String, (ExpressionInfo, FunctionBuilder)) = {
// See if we can find a constructor that accepts Seq[Expression]
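[Editor's note] Making expression public widens its audience from the built-in function table to any caller that wants to build a registry entry the same way, which is exactly what the SQLContext and HiveContext changes below do. A minimal sketch of the call pattern; every name here appears elsewhere in this diff, and the imports are assumed to be available to the caller:

    import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
    import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
    import org.apache.spark.sql.catalyst.expressions.ExpressionInfo
    import org.apache.spark.sql.execution.expressions.SparkPartitionID

    // expression[T] reflectively locates a constructor of T and wraps it in a
    // FunctionBuilder, yielding the (name, (info, builder)) entry a registry expects.
    val entry: (String, (ExpressionInfo, FunctionBuilder)) =
      FunctionRegistry.expression[SparkPartitionID]("spark__partition__id")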
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index dbb2a09846..56cd8f22e7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -31,6 +31,8 @@ import org.apache.spark.SparkContext
import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.{expression => FunctionExpression, FunctionBuilder}
+import org.apache.spark.sql.execution.expressions.SparkPartitionID
import org.apache.spark.sql.SQLConf.SQLConfEntry
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.errors.DialectException
@@ -140,7 +142,14 @@ class SQLContext(@transient val sparkContext: SparkContext)
// TODO how to handle the temp function per user session?
@transient
- protected[sql] lazy val functionRegistry: FunctionRegistry = FunctionRegistry.builtin
+ protected[sql] lazy val functionRegistry: FunctionRegistry = {
+ val reg = FunctionRegistry.builtin
+ val extendedFunctions = List[(String, (ExpressionInfo, FunctionBuilder))](
+ FunctionExpression[SparkPartitionID]("spark__partition__id")
+ )
+ extendedFunctions.foreach { case (name, (info, fun)) => reg.registerFunction(name, info, fun) }
+ reg
+ }
@transient
protected[sql] lazy val analyzer: Analyzer =
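[Editor's note] With the entry registered alongside the builtins, the function resolves from plain SQL on any SQLContext. A hedged usage sketch; the table name and data are illustrative, not part of this change:

    // Assumes an existing SparkContext named sc.
    val sqlContext = new org.apache.spark.sql.SQLContext(sc)
    val df = sqlContext.createDataFrame(Seq((1, "a"), (2, "b"))).toDF("id", "v")
    df.registerTempTable("people")
    // Each row reports the id of the partition that produced it.
    sqlContext.sql("SELECT spark__partition__id() FROM people").show()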
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala
index 61ef079d89..98c8eab837 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.types.{IntegerType, DataType}
/**
* Expression that returns the current partition id of the Spark task.
*/
-private[sql] case object SparkPartitionID extends LeafExpression with Nondeterministic {
+private[sql] case class SparkPartitionID() extends LeafExpression with Nondeterministic {
override def nullable: Boolean = false
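[Editor's note] The switch from case object to case class follows from how registration works: FunctionRegistry.expression instantiates the expression reflectively through a constructor (per the "constructor that accepts Seq[Expression]" comment above), and a singleton object exposes no constructor to find. A zero-argument case class also hands each call site a fresh instance. A hedged sketch of the builder shape the helper is expected to produce for a childless expression:

    import org.apache.spark.sql.catalyst.expressions.Expression
    import org.apache.spark.sql.execution.expressions.SparkPartitionID

    // Assumed shape: with no child expressions to pass along, the reflective
    // builder reduces to ignoring its arguments and calling the nullary constructor.
    val builder: Seq[Expression] => Expression = _ => SparkPartitionID()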
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index cec61b66b1..0148991512 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -741,7 +741,7 @@ object functions {
* @group normal_funcs
* @since 1.4.0
*/
- def sparkPartitionId(): Column = execution.expressions.SparkPartitionID
+ def sparkPartitionId(): Column = execution.expressions.SparkPartitionID()
/**
* Computes the square root of the specified float value.
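[Editor's note] For DataFrame users the same expression is already exposed as the sparkPartitionId() column function (since 1.4.0 per the doc comment above); only its body changes here, to construct an instance. A usage sketch, assuming a SQLContext named sqlContext is in scope:

    import org.apache.spark.sql.functions.sparkPartitionId

    val df = sqlContext.range(0, 100)
    // Tags each row with the partition it lives in; distinct ids reveal the layout.
    df.select(sparkPartitionId().as("pid")).distinct().show()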
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
index c1516b450c..9b326c1635 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -51,6 +51,13 @@ class UDFSuite extends QueryTest {
df.selectExpr("count(distinct a)")
}
+ test("SPARK-8003 spark__partition__id") {
+ val df = Seq((1, "Tearing down the walls that divide us")).toDF("id", "saying")
+ df.registerTempTable("tmp_table")
+ checkAnswer(ctx.sql("select spark__partition__id() from tmp_table").toDF(), Row(0))
+ ctx.dropTempTable("tmp_table")
+ }
+
test("error reporting for incorrect number of arguments") {
val df = ctx.emptyDataFrame
val e = intercept[AnalysisException] {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/expression/NondeterministicSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/expression/NondeterministicSuite.scala
index 1c5a2ed2c0..b6e79ff9cc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/expression/NondeterministicSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/expression/NondeterministicSuite.scala
@@ -27,6 +27,6 @@ class NondeterministicSuite extends SparkFunSuite with ExpressionEvalHelper {
}
test("SparkPartitionID") {
- checkEvaluation(SparkPartitionID, 0)
+ checkEvaluation(SparkPartitionID(), 0)
}
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index 110f51a305..8b35c1275f 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -38,6 +38,9 @@ import org.apache.spark.Logging
import org.apache.spark.SparkContext
import org.apache.spark.annotation.Experimental
import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.{expression => FunctionExpression, FunctionBuilder}
+import org.apache.spark.sql.catalyst.expressions.ExpressionInfo
+import org.apache.spark.sql.execution.expressions.SparkPartitionID
import org.apache.spark.sql.SQLConf.SQLConfEntry
import org.apache.spark.sql.SQLConf.SQLConfEntry._
import org.apache.spark.sql.catalyst.{TableIdentifier, ParserDialect}
@@ -372,8 +375,14 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging {
// Note that HiveUDFs will be overridden by functions registered in this context.
@transient
- override protected[sql] lazy val functionRegistry: FunctionRegistry =
- new HiveFunctionRegistry(FunctionRegistry.builtin)
+ override protected[sql] lazy val functionRegistry: FunctionRegistry = {
+ val reg = new HiveFunctionRegistry(FunctionRegistry.builtin)
+ val extendedFunctions = List[(String, (ExpressionInfo, FunctionBuilder))](
+ FunctionExpression[SparkPartitionID]("spark__partition__id")
+ )
+ extendedFunctions.foreach { case (name, (info, fun)) => reg.registerFunction(name, info, fun) }
+ reg
+ }
/* An analyzer that uses the Hive metastore. */
@transient
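[Editor's note] The Hive path layers the same extended-function list on top of a HiveFunctionRegistry, and, per the note above, entries registered here shadow Hive UDFs of the same name. A hedged sketch of how the list would grow for further built-ins; MyPartitionTag is a hypothetical zero-arg expression class, shown only to illustrate the pattern:

    // Hypothetical extension of the list built above; MyPartitionTag does not exist.
    val extendedFunctions = List[(String, (ExpressionInfo, FunctionBuilder))](
      FunctionExpression[SparkPartitionID]("spark__partition__id"),
      FunctionExpression[MyPartitionTag]("my__partition__tag")
    )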
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala
index 4056dee777..9cea5d413c 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala
@@ -17,13 +17,14 @@
package org.apache.spark.sql.hive
-import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.{Row, QueryTest}
case class FunctionResult(f1: String, f2: String)
class UDFSuite extends QueryTest {
private lazy val ctx = org.apache.spark.sql.hive.test.TestHive
+ import ctx.implicits._
test("UDF case insensitive") {
ctx.udf.register("random0", () => { Math.random() })
@@ -33,4 +34,10 @@ class UDFSuite extends QueryTest {
assert(ctx.sql("SELECT RANDOm1() FROM src LIMIT 1").head().getDouble(0) >= 0.0)
assert(ctx.sql("SELECT strlenscala('test', 1) FROM src LIMIT 1").head().getInt(0) === 5)
}
+
+ test("SPARK-8003 spark__partition__id") {
+ val df = Seq((1, "Two Fiiiiive")).toDF("id", "saying")
+ ctx.registerDataFrameAsTable(df, "test_table")
+ checkAnswer(ctx.sql("select spark__partition__id() from test_table LIMIT 1").toDF(), Row(0))
+ }
}