diff options
author | Dongjoon Hyun <dongjoon@apache.org> | 2016-07-06 10:54:43 +0800 |
---|---|---|
committer | Wenchen Fan <wenchen@databricks.com> | 2016-07-06 10:54:43 +0800 |
commit | d0d28507cacfca5919dbfb4269892d58b62e8662 (patch) | |
tree | 474edb76ba7642ca385c743a5aad5bf437f46ef4 /sql/catalyst/src/main | |
parent | fdde7d0aa0ef69d0e9a88cf712601bba1d5b0706 (diff) | |
download | spark-d0d28507cacfca5919dbfb4269892d58b62e8662.tar.gz spark-d0d28507cacfca5919dbfb4269892d58b62e8662.tar.bz2 spark-d0d28507cacfca5919dbfb4269892d58b62e8662.zip |
[SPARK-16286][SQL] Implement stack table generating function
## What changes were proposed in this pull request?
This PR implements `stack` table generating function.
## How was this patch tested?
Pass the Jenkins tests including new testcases.
Author: Dongjoon Hyun <dongjoon@apache.org>
Closes #14033 from dongjoon-hyun/SPARK-16286.
Diffstat (limited to 'sql/catalyst/src/main')
-rw-r--r-- | sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala | 1 | ||||
-rw-r--r-- | sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala | 53 |
2 files changed, 54 insertions, 0 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 021bec7f5f..f6ebcaeded 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -182,6 +182,7 @@ object FunctionRegistry { expression[PosExplode]("posexplode"), expression[Rand]("rand"), expression[Randn]("randn"), + expression[Stack]("stack"), expression[CreateStruct]("struct"), expression[CaseWhen]("when"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 99b97c8ea2..9d5c856a23 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -94,6 +94,59 @@ case class UserDefinedGenerator( } /** + * Separate v1, ..., vk into n rows. Each row will have k/n columns. n must be constant. + * {{{ + * SELECT stack(2, 1, 2, 3) -> + * 1 2 + * 3 NULL + * }}} + */ +@ExpressionDescription( + usage = "_FUNC_(n, v1, ..., vk) - Separate v1, ..., vk into n rows.", + extended = "> SELECT _FUNC_(2, 1, 2, 3);\n [1,2]\n [3,null]") +case class Stack(children: Seq[Expression]) + extends Expression with Generator with CodegenFallback { + + private lazy val numRows = children.head.eval().asInstanceOf[Int] + private lazy val numFields = Math.ceil((children.length - 1.0) / numRows).toInt + + override def checkInputDataTypes(): TypeCheckResult = { + if (children.length <= 1) { + TypeCheckResult.TypeCheckFailure(s"$prettyName requires at least 2 arguments.") + } else if (children.head.dataType != IntegerType || !children.head.foldable || numRows < 1) { + TypeCheckResult.TypeCheckFailure("The number of rows must be a positive constant integer.") + } else { + for (i <- 1 until children.length) { + val j = (i - 1) % numFields + if (children(i).dataType != elementSchema.fields(j).dataType) { + return TypeCheckResult.TypeCheckFailure( + s"Argument ${j + 1} (${elementSchema.fields(j).dataType}) != " + + s"Argument $i (${children(i).dataType})") + } + } + TypeCheckResult.TypeCheckSuccess + } + } + + override def elementSchema: StructType = + StructType(children.tail.take(numFields).zipWithIndex.map { + case (e, index) => StructField(s"col$index", e.dataType) + }) + + override def eval(input: InternalRow): TraversableOnce[InternalRow] = { + val values = children.tail.map(_.eval(input)).toArray + for (row <- 0 until numRows) yield { + val fields = new Array[Any](numFields) + for (col <- 0 until numFields) { + val index = row * numFields + col + fields.update(col, if (index < values.length) values(index) else null) + } + InternalRow(fields: _*) + } + } +} + +/** * A base class for Explode and PosExplode */ abstract class ExplodeBase(child: Expression, position: Boolean) |