From d0d28507cacfca5919dbfb4269892d58b62e8662 Mon Sep 17 00:00:00 2001
From: Dongjoon Hyun
Date: Wed, 6 Jul 2016 10:54:43 +0800
Subject: [SPARK-16286][SQL] Implement stack table generating function

## What changes were proposed in this pull request?

This PR implements the `stack` table generating function.

## How was this patch tested?

Pass the Jenkins tests, including the new test cases.

Author: Dongjoon Hyun

Closes #14033 from dongjoon-hyun/SPARK-16286.
---
 .../sql/catalyst/analysis/FunctionRegistry.scala   |  1 +
 .../sql/catalyst/expressions/generators.scala      | 53 ++++++++++++++++++++++
 .../expressions/GeneratorExpressionSuite.scala     | 18 ++++++++
 .../apache/spark/sql/GeneratorFunctionSuite.scala  | 53 ++++++++++++++++++++++
 .../apache/spark/sql/hive/HiveSessionCatalog.scala |  2 +-
 5 files changed, 126 insertions(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 021bec7f5f..f6ebcaeded 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -182,6 +182,7 @@ object FunctionRegistry {
     expression[PosExplode]("posexplode"),
     expression[Rand]("rand"),
     expression[Randn]("randn"),
+    expression[Stack]("stack"),
     expression[CreateStruct]("struct"),
     expression[CaseWhen]("when"),
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
index 99b97c8ea2..9d5c856a23 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
@@ -93,6 +93,59 @@ case class UserDefinedGenerator(
   override def toString: String = s"UserDefinedGenerator(${children.mkString(",")})"
 }
 
+/**
+ * Separate v1, ..., vk into n rows. Each row will have ceil(k/n) columns. n must be constant.
+ * {{{
+ *   SELECT stack(2, 1, 2, 3) ->
+ *   1      2
+ *   3      NULL
+ * }}}
+ */
+@ExpressionDescription(
+  usage = "_FUNC_(n, v1, ..., vk) - Separate v1, ..., vk into n rows.",
+  extended = "> SELECT _FUNC_(2, 1, 2, 3);\n  [1,2]\n  [3,null]")
+case class Stack(children: Seq[Expression])
+  extends Expression with Generator with CodegenFallback {
+
+  private lazy val numRows = children.head.eval().asInstanceOf[Int]
+  private lazy val numFields = Math.ceil((children.length - 1.0) / numRows).toInt
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (children.length <= 1) {
+      TypeCheckResult.TypeCheckFailure(s"$prettyName requires at least 2 arguments.")
+    } else if (children.head.dataType != IntegerType || !children.head.foldable || numRows < 1) {
+      TypeCheckResult.TypeCheckFailure("The number of rows must be a positive constant integer.")
+    } else {
+      for (i <- 1 until children.length) {
+        val j = (i - 1) % numFields
+        if (children(i).dataType != elementSchema.fields(j).dataType) {
+          return TypeCheckResult.TypeCheckFailure(
+            s"Argument ${j + 1} (${elementSchema.fields(j).dataType}) != " +
+              s"Argument $i (${children(i).dataType})")
+        }
+      }
+      TypeCheckResult.TypeCheckSuccess
+    }
+  }
+
+  override def elementSchema: StructType =
+    StructType(children.tail.take(numFields).zipWithIndex.map {
+      case (e, index) => StructField(s"col$index", e.dataType)
+    })
+
+  override def eval(input: InternalRow): TraversableOnce[InternalRow] = {
+    val values = children.tail.map(_.eval(input)).toArray
+    for (row <- 0 until numRows) yield {
+      val fields = new Array[Any](numFields)
+      for (col <- 0 until numFields) {
+        val index = row * numFields + col
+        fields.update(col, if (index < values.length) values(index) else null)
+      }
+      InternalRow(fields: _*)
+    }
+  }
+}
+
 /**
  * A base class for Explode and PosExplode
  */
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratorExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratorExpressionSuite.scala
index e79f89b497..e29dfa41f1 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratorExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratorExpressionSuite.scala
@@ -63,4 +63,22 @@ class GeneratorExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
       ))),
       correct_answer)
   }
+
+  test("stack") {
+    checkTuple(Stack(Seq(1, 1).map(Literal(_))), Seq(create_row(1)))
+    checkTuple(Stack(Seq(1, 1, 2).map(Literal(_))), Seq(create_row(1, 2)))
+    checkTuple(Stack(Seq(2, 1, 2).map(Literal(_))), Seq(create_row(1), create_row(2)))
+    checkTuple(Stack(Seq(2, 1, 2, 3).map(Literal(_))), Seq(create_row(1, 2), create_row(3, null)))
+    checkTuple(Stack(Seq(3, 1, 2, 3).map(Literal(_))), Seq(1, 2, 3).map(create_row(_)))
+    checkTuple(Stack(Seq(4, 1, 2, 3).map(Literal(_))), Seq(1, 2, 3, null).map(create_row(_)))
+
+    checkTuple(
+      Stack(Seq(3, 1, 1.0, "a", 2, 2.0, "b", 3, 3.0, "c").map(Literal(_))),
+      Seq(create_row(1, 1.0, "a"), create_row(2, 2.0, "b"), create_row(3, 3.0, "c")))
+
+    assert(Stack(Seq(Literal(1))).checkInputDataTypes().isFailure)
+    assert(Stack(Seq(Literal(1.0))).checkInputDataTypes().isFailure)
+    assert(Stack(Seq(Literal(1), Literal(1), Literal(1.0))).checkInputDataTypes().isSuccess)
+    assert(Stack(Seq(Literal(2), Literal(1), Literal(1.0))).checkInputDataTypes().isFailure)
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala
index d8a0aa4d52..aedc0a8d6f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala
@@ -23,6 +23,59 @@ import org.apache.spark.sql.test.SharedSQLContext
 class GeneratorFunctionSuite extends QueryTest with SharedSQLContext {
   import testImplicits._
 
+  test("stack") {
+    val df = spark.range(1)
+
+    // An empty DataFrame suppresses result generation.
+    checkAnswer(spark.emptyDataFrame.selectExpr("stack(1, 1, 2, 3)"), Nil)
+
+    // Rows & columns
+    checkAnswer(df.selectExpr("stack(1, 1, 2, 3)"), Row(1, 2, 3) :: Nil)
+    checkAnswer(df.selectExpr("stack(2, 1, 2, 3)"), Row(1, 2) :: Row(3, null) :: Nil)
+    checkAnswer(df.selectExpr("stack(3, 1, 2, 3)"), Row(1) :: Row(2) :: Row(3) :: Nil)
+    checkAnswer(df.selectExpr("stack(4, 1, 2, 3)"), Row(1) :: Row(2) :: Row(3) :: Row(null) :: Nil)
+
+    // Various column types
+    checkAnswer(df.selectExpr("stack(3, 1, 1.1, 'a', 2, 2.2, 'b', 3, 3.3, 'c')"),
+      Row(1, 1.1, "a") :: Row(2, 2.2, "b") :: Row(3, 3.3, "c") :: Nil)
+
+    // Repeat generation at every input row
+    checkAnswer(spark.range(2).selectExpr("stack(2, 1, 2, 3)"),
+      Row(1, 2) :: Row(3, null) :: Row(1, 2) :: Row(3, null) :: Nil)
+
+    // The first argument must be a positive constant integer.
+    val m = intercept[AnalysisException] {
+      df.selectExpr("stack(1.1, 1, 2, 3)")
+    }.getMessage
+    assert(m.contains("The number of rows must be a positive constant integer."))
+    val m2 = intercept[AnalysisException] {
+      df.selectExpr("stack(-1, 1, 2, 3)")
+    }.getMessage
+    assert(m2.contains("The number of rows must be a positive constant integer."))
+
+    // The data for the same column should have the same type.
+    val m3 = intercept[AnalysisException] {
+      df.selectExpr("stack(2, 1, '2.2')")
+    }.getMessage
+    assert(m3.contains("data type mismatch: Argument 1 (IntegerType) != Argument 2 (StringType)"))
+
+    // stack on column data
+    val df2 = Seq((2, 1, 2, 3)).toDF("n", "a", "b", "c")
+    checkAnswer(df2.selectExpr("stack(2, a, b, c)"), Row(1, 2) :: Row(3, null) :: Nil)
+
+    val m4 = intercept[AnalysisException] {
+      df2.selectExpr("stack(n, a, b, c)")
+    }.getMessage
+    assert(m4.contains("The number of rows must be a positive constant integer."))
+
+    val df3 = Seq((2, 1, 2.0)).toDF("n", "a", "b")
+    val m5 = intercept[AnalysisException] {
+      df3.selectExpr("stack(2, a, b)")
+    }.getMessage
+    assert(m5.contains("data type mismatch: Argument 1 (IntegerType) != Argument 2 (DoubleType)"))
+
+  }
+
   test("single explode") {
     val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
     checkAnswer(
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
index ebb6711f6a..fdc4c18e70 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
@@ -236,7 +236,7 @@ private[sql] class HiveSessionCatalog(
   // str_to_map, windowingtablefunction.
   private val hiveFunctions = Seq(
     "hash", "java_method", "histogram_numeric",
-    "parse_url", "percentile", "percentile_approx", "reflect", "sentences", "stack", "str_to_map",
+    "parse_url", "percentile", "percentile_approx", "reflect", "sentences", "str_to_map",
    "xpath", "xpath_double", "xpath_float", "xpath_int", "xpath_long", "xpath_number",
    "xpath_short", "xpath_string"
  )
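For quick reference, here is a minimal usage sketch mirroring the test cases above. It assumes a Spark build that includes this patch and a `spark-shell` session where `spark` and the DataFrame implicits are already in scope; the input column names `c1` through `c6` are illustrative only.

```scala
// stack(n, v1, ..., vk) pivots k values into n rows of ceil(k/n) columns,
// which are named col0, col1, ... by Stack.elementSchema.
val df = Seq((1, 2.0, "a", 2, 4.0, "b")).toDF("c1", "c2", "c3", "c4", "c5", "c6")
df.selectExpr("stack(2, c1, c2, c3, c4, c5, c6)").show()
// +----+----+----+
// |col0|col1|col2|
// +----+----+----+
// |   1| 2.0|   a|
// |   2| 4.0|   b|
// +----+----+----+

// When k is not a multiple of n, the trailing cells are padded with NULL.
spark.range(1).selectExpr("stack(2, 1, 2, 3)").show()
// +----+----+
// |col0|col1|
// +----+----+
// |   1|   2|
// |   3|null|
// +----+----+
```

Because the row count must be a foldable positive integer and the values feeding each output column must share a data type, both constraints surface as an `AnalysisException` at analysis time rather than at execution, as the `GeneratorFunctionSuite` tests above verify.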