author      Dongjoon Hyun <dongjoon@apache.org>    2016-07-06 10:54:43 +0800
committer   Wenchen Fan <wenchen@databricks.com>   2016-07-06 10:54:43 +0800
commit      d0d28507cacfca5919dbfb4269892d58b62e8662 (patch)
tree        474edb76ba7642ca385c743a5aad5bf437f46ef4 /sql/catalyst
parent      fdde7d0aa0ef69d0e9a88cf712601bba1d5b0706 (diff)
[SPARK-16286][SQL] Implement stack table generating function
## What changes were proposed in this pull request?

This PR implements the `stack` table generating function.

## How was this patch tested?

Pass the Jenkins tests including new test cases.

Author: Dongjoon Hyun <dongjoon@apache.org>

Closes #14033 from dongjoon-hyun/SPARK-16286.
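At the SQL level, `stack(n, v1, ..., vk)` turns one row of k values into n rows of ceil(k/n) columns, padding the final row with NULLs when k does not divide evenly. The snippet below is a minimal sketch of the behavior as exercised through a local `SparkSession`; the `StackDemo` object and the session setup are illustrative only and not part of this commit.

```scala
// Illustrative only: a local driver program exercising the new `stack`
// function through SQL. Assumes a Spark build containing this patch.
import org.apache.spark.sql.SparkSession

object StackDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("stack-demo")
      .getOrCreate()

    // stack(2, 1, 2, 3): 3 values into 2 rows of ceil(3/2) = 2 columns;
    // the last row is padded with NULL.
    spark.sql("SELECT stack(2, 1, 2, 3)").show()
    // +----+----+
    // |col0|col1|
    // +----+----+
    // |   1|   2|
    // |   3|null|
    // +----+----+

    spark.stop()
  }
}
```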
Diffstat (limited to 'sql/catalyst')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala            |  1
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala               | 53
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratorExpressionSuite.scala | 18
3 files changed, 72 insertions, 0 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 021bec7f5f..f6ebcaeded 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -182,6 +182,7 @@ object FunctionRegistry {
     expression[PosExplode]("posexplode"),
     expression[Rand]("rand"),
     expression[Randn]("randn"),
+    expression[Stack]("stack"),
     expression[CreateStruct]("struct"),
     expression[CaseWhen]("when"),
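Registration is a single entry because `FunctionRegistry` derives a builder from the expression class itself. The sketch below is a simplified, self-contained approximation of what `expression[Stack]("stack")` produces, not Spark's actual implementation; the `RegistrySketch` object and its trimmed-down `Expression` and `FunctionBuilder` types are hypothetical stand-ins. The idea is that the registry maps the SQL name to a builder that instantiates the expression from the parsed argument list, so `stack(...)` resolves to `Stack(children)` during analysis.

```scala
// A simplified approximation (not Spark's real code) of how a registry
// entry like expression[Stack]("stack") can be derived from the class.
import scala.reflect.ClassTag

object RegistrySketch {
  // Trimmed-down stand-ins for Catalyst's types.
  trait Expression
  case class Stack(children: Seq[Expression]) extends Expression

  type FunctionBuilder = Seq[Expression] => Expression

  // Reflectively wrap the Seq[Expression] constructor as a builder,
  // keyed by the SQL function name.
  def expression[T <: Expression](name: String)(
      implicit tag: ClassTag[T]): (String, FunctionBuilder) = {
    val builder = (args: Seq[Expression]) => {
      val ctor = tag.runtimeClass.getConstructor(classOf[Seq[_]])
      ctor.newInstance(args).asInstanceOf[Expression]
    }
    (name, builder)
  }

  def main(args: Array[String]): Unit = {
    val (name, builder) = expression[Stack]("stack")
    println(s"$name -> ${builder(Seq.empty)}") // stack -> Stack(List())
  }
}
```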
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
index 99b97c8ea2..9d5c856a23 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
@@ -94,6 +94,59 @@ case class UserDefinedGenerator(
 }
 
 /**
+ * Separate v1, ..., vk into n rows. Each row will have k/n columns. n must be constant.
+ * {{{
+ *   SELECT stack(2, 1, 2, 3) ->
+ *   1      2
+ *   3      NULL
+ * }}}
+ */
+@ExpressionDescription(
+  usage = "_FUNC_(n, v1, ..., vk) - Separate v1, ..., vk into n rows.",
+  extended = "> SELECT _FUNC_(2, 1, 2, 3);\n  [1,2]\n  [3,null]")
+case class Stack(children: Seq[Expression])
+    extends Expression with Generator with CodegenFallback {
+
+  private lazy val numRows = children.head.eval().asInstanceOf[Int]
+  private lazy val numFields = Math.ceil((children.length - 1.0) / numRows).toInt
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (children.length <= 1) {
+      TypeCheckResult.TypeCheckFailure(s"$prettyName requires at least 2 arguments.")
+    } else if (children.head.dataType != IntegerType || !children.head.foldable || numRows < 1) {
+      TypeCheckResult.TypeCheckFailure("The number of rows must be a positive constant integer.")
+    } else {
+      for (i <- 1 until children.length) {
+        val j = (i - 1) % numFields
+        if (children(i).dataType != elementSchema.fields(j).dataType) {
+          return TypeCheckResult.TypeCheckFailure(
+            s"Argument ${j + 1} (${elementSchema.fields(j).dataType}) != " +
+              s"Argument $i (${children(i).dataType})")
+        }
+      }
+      TypeCheckResult.TypeCheckSuccess
+    }
+  }
+
+  override def elementSchema: StructType =
+    StructType(children.tail.take(numFields).zipWithIndex.map {
+      case (e, index) => StructField(s"col$index", e.dataType)
+    })
+
+  override def eval(input: InternalRow): TraversableOnce[InternalRow] = {
+    val values = children.tail.map(_.eval(input)).toArray
+    for (row <- 0 until numRows) yield {
+      val fields = new Array[Any](numFields)
+      for (col <- 0 until numFields) {
+        val index = row * numFields + col
+        fields.update(col, if (index < values.length) values(index) else null)
+      }
+      InternalRow(fields: _*)
+    }
+  }
+}
+
+/**
  * A base class for Explode and PosExplode
  */
 abstract class ExplodeBase(child: Expression, position: Boolean)
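The heart of `eval` is row-major index arithmetic: after the leading `n`, value i (0-based) lands in row i / numFields, column i % numFields, and any position past the end of the input is padded with null. The standalone sketch below isolates that logic so it can be traced by hand; the `StackEvalSketch` object and `stackRows` helper are hypothetical, for illustration only.

```scala
// Isolates the padding logic of Stack.eval on plain Scala collections.
// `stackRows` is a hypothetical helper, not part of the patch.
object StackEvalSketch {
  def stackRows(numRows: Int, values: Array[Any]): Seq[Seq[Any]] = {
    // Same column count as Stack.numFields: ceil(k / n) for k values, n rows.
    val numFields = math.ceil(values.length.toDouble / numRows).toInt
    for (row <- 0 until numRows) yield {
      for (col <- 0 until numFields) yield {
        // Row-major layout: row * numFields + col walks the inputs in order;
        // indexes past the input are padded with null.
        val index = row * numFields + col
        if (index < values.length) values(index) else null
      }
    }
  }

  def main(args: Array[String]): Unit = {
    // stack(2, 1, 2, 3): 3 values into 2 rows of 2 columns, last cell padded.
    println(stackRows(2, Array[Any](1, 2, 3)))
    // Vector(Vector(1, 2), Vector(3, null))
  }
}
```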
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratorExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratorExpressionSuite.scala
index e79f89b497..e29dfa41f1 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratorExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratorExpressionSuite.scala
@@ -63,4 +63,22 @@ class GeneratorExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
       ))),
       correct_answer)
   }
+
+  test("stack") {
+    checkTuple(Stack(Seq(1, 1).map(Literal(_))), Seq(create_row(1)))
+    checkTuple(Stack(Seq(1, 1, 2).map(Literal(_))), Seq(create_row(1, 2)))
+    checkTuple(Stack(Seq(2, 1, 2).map(Literal(_))), Seq(create_row(1), create_row(2)))
+    checkTuple(Stack(Seq(2, 1, 2, 3).map(Literal(_))), Seq(create_row(1, 2), create_row(3, null)))
+    checkTuple(Stack(Seq(3, 1, 2, 3).map(Literal(_))), Seq(1, 2, 3).map(create_row(_)))
+    checkTuple(Stack(Seq(4, 1, 2, 3).map(Literal(_))), Seq(1, 2, 3, null).map(create_row(_)))
+
+    checkTuple(
+      Stack(Seq(3, 1, 1.0, "a", 2, 2.0, "b", 3, 3.0, "c").map(Literal(_))),
+      Seq(create_row(1, 1.0, "a"), create_row(2, 2.0, "b"), create_row(3, 3.0, "c")))
+
+    assert(Stack(Seq(Literal(1))).checkInputDataTypes().isFailure)
+    assert(Stack(Seq(Literal(1.0))).checkInputDataTypes().isFailure)
+    assert(Stack(Seq(Literal(1), Literal(1), Literal(1.0))).checkInputDataTypes().isSuccess)
+    assert(Stack(Seq(Literal(2), Literal(1), Literal(1.0))).checkInputDataTypes().isFailure)
+  }
 }
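The final two assertions capture the cyclic type rule that `checkInputDataTypes` enforces: argument i (1-based, after `n`) must have the same data type as column (i - 1) % numFields, that is, each value must agree with the column it lands in. Below is a tiny standalone sketch of that rule; the `StackTypeCheckSketch` object and `consistent` helper are hypothetical.

```scala
// A standalone model of Stack's type-consistency rule, using strings in
// place of Catalyst DataTypes. Hypothetical helper for illustration.
object StackTypeCheckSketch {
  def consistent(numRows: Int, argTypes: Seq[String]): Boolean = {
    val numFields = math.ceil(argTypes.length.toDouble / numRows).toInt
    argTypes.zipWithIndex.forall { case (t, i) =>
      // Each value must match the type of the column it lands in, which is
      // fixed by the corresponding argument of the first row.
      t == argTypes(i % numFields)
    }
  }

  def main(args: Array[String]): Unit = {
    // stack(1, 1, 1.0): one row, so int and double are separate columns: OK.
    println(consistent(1, Seq("int", "double"))) // true
    // stack(2, 1, 1.0): two rows, so 1.0 must sit under the int column: fail.
    println(consistent(2, Seq("int", "double"))) // false
  }
}
```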