From 46395db80e3304e3f3a1ebdc8aadb8f2819b48b4 Mon Sep 17 00:00:00 2001
From: Dongjoon Hyun
Date: Thu, 30 Jun 2016 12:03:54 -0700
Subject: [SPARK-16289][SQL] Implement posexplode table generating function

## What changes were proposed in this pull request?

This PR implements the `posexplode` table generating function. Currently, the master branch raises the following exception for a `map` argument, which is different from Hive's behavior.

**Before**
```scala
scala> sql("select posexplode(map('a', 1, 'b', 2))").show
org.apache.spark.sql.AnalysisException: No handler for Hive UDF ... posexplode() takes an array as a parameter; line 1 pos 7
```

**After**
```scala
scala> sql("select posexplode(map('a', 1, 'b', 2))").show
+---+---+-----+
|pos|key|value|
+---+---+-----+
|  0|  a|    1|
|  1|  b|    2|
+---+---+-----+
```

For an `array` argument, the behavior after this change is the same as before.
```
scala> sql("select posexplode(array(1, 2, 3))").show
+---+---+
|pos|col|
+---+---+
|  0|  1|
|  1|  2|
|  2|  3|
+---+---+
```

## How was this patch tested?

Pass the Jenkins tests with the newly added test cases.

Author: Dongjoon Hyun

Closes #13971 from dongjoon-hyun/SPARK-16289.
---
 R/pkg/NAMESPACE                                    |  1 +
 R/pkg/R/functions.R                                | 17 ++++
 R/pkg/R/generics.R                                 |  4 +
 R/pkg/inst/tests/testthat/test_sparkSQL.R          |  2 +-
 python/pyspark/sql/functions.py                    | 21 +++++
 .../sql/catalyst/analysis/FunctionRegistry.scala   |  1 +
 .../sql/catalyst/expressions/generators.scala      | 66 +++++++++++++---
 .../analysis/ExpressionTypeCheckingSuite.scala     |  2 +
 .../expressions/GeneratorExpressionSuite.scala     | 71 +++++++++++++++++
 .../main/scala/org/apache/spark/sql/Column.scala   |  1 +
 .../scala/org/apache/spark/sql/functions.scala     |  8 ++
 .../apache/spark/sql/ColumnExpressionSuite.scala   | 60 --------------
 .../apache/spark/sql/GeneratorFunctionSuite.scala  | 92 ++++++++++++++++++++++
 .../apache/spark/sql/hive/HiveSessionCatalog.scala |  2 +-
 14 files changed, 276 insertions(+), 72 deletions(-)
 create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratorExpressionSuite.scala
 create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala

diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index e0ffde922d..abc65887bc 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -234,6 +234,7 @@ exportMethods("%in%",
               "over",
               "percent_rank",
               "pmod",
+              "posexplode",
               "quarter",
               "rand",
               "randn",
diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R
index 09e5afa970..52d46f9d76 100644
--- a/R/pkg/R/functions.R
+++ b/R/pkg/R/functions.R
@@ -2934,3 +2934,20 @@ setMethod("sort_array",
             jc <- callJStatic("org.apache.spark.sql.functions", "sort_array", x@jc, asc)
             column(jc)
           })
+
+#' posexplode
+#'
+#' Creates a new row for each element with position in the given array or map column.
+#'
+#' @rdname posexplode
+#' @name posexplode
+#' @family collection_funcs
+#' @export
+#' @examples \dontrun{posexplode(df$c)}
+#' @note posexplode since 2.1.0
+setMethod("posexplode",
+          signature(x = "Column"),
+          function(x) {
+            jc <- callJStatic("org.apache.spark.sql.functions", "posexplode", x@jc)
+            column(jc)
+          })
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 0e4350f861..d9080b6b32 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -1050,6 +1050,10 @@ setGeneric("percent_rank", function(x) { standardGeneric("percent_rank") })
 #' @export
 setGeneric("pmod", function(y, x) { standardGeneric("pmod") })
 
+#' @rdname posexplode
+#' @export
+setGeneric("posexplode", function(x) { standardGeneric("posexplode") })
+
 #' @rdname quarter
 #' @export
 setGeneric("quarter", function(x) { standardGeneric("quarter") })
diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R
index d4662ad4e3..588c217f3c 100644
--- a/R/pkg/inst/tests/testthat/test_sparkSQL.R
+++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R
@@ -1065,7 +1065,7 @@ test_that("column functions", {
   c4 <- explode(c) + expm1(c) + factorial(c) + first(c) + floor(c) + hex(c)
   c5 <- hour(c) + initcap(c) + last(c) + last_day(c) + length(c)
   c6 <- log(c) + (c) + log1p(c) + log2(c) + lower(c) + ltrim(c) + max(c) + md5(c)
-  c7 <- mean(c) + min(c) + month(c) + negate(c) + quarter(c)
+  c7 <- mean(c) + min(c) + month(c) + negate(c) + posexplode(c) + quarter(c)
   c8 <- reverse(c) + rint(c) + round(c) + rtrim(c) + sha1(c) + monotonically_increasing_id()
   c9 <- signum(c) + sin(c) + sinh(c) + size(c) + stddev(c) + soundex(c) + sqrt(c) + sum(c)
   c10 <- sumDistinct(c) + tan(c) + tanh(c) + toDegrees(c) + toRadians(c)
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 15cefc8cf1..7a7345170c 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -1637,6 +1637,27 @@ def explode(col):
     return Column(jc)
 
 
+@since(2.1)
+def posexplode(col):
+    """Returns a new row for each element with position in the given array or map.
+
+    >>> from pyspark.sql import Row
+    >>> eDF = spark.createDataFrame([Row(a=1, intlist=[1,2,3], mapfield={"a": "b"})])
+    >>> eDF.select(posexplode(eDF.intlist)).collect()
+    [Row(pos=0, col=1), Row(pos=1, col=2), Row(pos=2, col=3)]
+
+    >>> eDF.select(posexplode(eDF.mapfield)).show()
+    +---+---+-----+
+    |pos|key|value|
+    +---+---+-----+
+    |  0|  a|    b|
+    +---+---+-----+
+    """
+    sc = SparkContext._active_spark_context
+    jc = sc._jvm.functions.posexplode(_to_java_column(col))
+    return Column(jc)
+
+
 @ignore_unicode_prefix
 @since(1.6)
 def get_json_object(col, path):
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 3f9227a8ae..3fbdb2ab57 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -176,6 +176,7 @@ object FunctionRegistry {
     expression[NullIf]("nullif"),
     expression[Nvl]("nvl"),
     expression[Nvl2]("nvl2"),
+    expression[PosExplode]("posexplode"),
     expression[Rand]("rand"),
     expression[Randn]("randn"),
     expression[CreateStruct]("struct"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
index 12c35644e5..4e91cc5aec 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
@@ -94,13 +94,10 @@ case class UserDefinedGenerator(
 }
 
 /**
- * Given an input array produces a sequence of rows for each value in the array.
+ * A base class for Explode and PosExplode
  */
-// scalastyle:off line.size.limit
-@ExpressionDescription(
-  usage = "_FUNC_(a) - Separates the elements of array a into multiple rows, or the elements of a map into multiple rows and columns.")
-// scalastyle:on line.size.limit
-case class Explode(child: Expression) extends UnaryExpression with Generator with CodegenFallback {
+abstract class ExplodeBase(child: Expression, position: Boolean)
+  extends UnaryExpression with Generator with CodegenFallback with Serializable {
 
   override def children: Seq[Expression] = child :: Nil
 
@@ -115,9 +112,26 @@ case class Explode(child: Expression) extends UnaryExpression with Generator wit
 
   // hive-compatible default alias for explode function ("col" for array, "key", "value" for map)
   override def elementSchema: StructType = child.dataType match {
-    case ArrayType(et, containsNull) => new StructType().add("col", et, containsNull)
+    case ArrayType(et, containsNull) =>
+      if (position) {
+        new StructType()
+          .add("pos", IntegerType, false)
+          .add("col", et, containsNull)
+      } else {
+        new StructType()
+          .add("col", et, containsNull)
+      }
     case MapType(kt, vt, valueContainsNull) =>
-      new StructType().add("key", kt, false).add("value", vt, valueContainsNull)
+      if (position) {
+        new StructType()
+          .add("pos", IntegerType, false)
+          .add("key", kt, false)
+          .add("value", vt, valueContainsNull)
+      } else {
+        new StructType()
+          .add("key", kt, false)
+          .add("value", vt, valueContainsNull)
+      }
   }
 
   override def eval(input: InternalRow): TraversableOnce[InternalRow] = {
@@ -129,7 +143,7 @@ case class Explode(child: Expression) extends UnaryExpression with Generator wit
       } else {
         val rows = new Array[InternalRow](inputArray.numElements())
         inputArray.foreach(et, (i, e) => {
-          rows(i) = InternalRow(e)
+          rows(i) = if (position) InternalRow(i, e) else InternalRow(e)
         })
         rows
       }
@@ -141,7 +155,7 @@ case class Explode(child: Expression) extends UnaryExpression with Generator wit
         val rows = new Array[InternalRow](inputMap.numElements())
         var i = 0
         inputMap.foreach(kt, vt, (k, v) => {
-          rows(i) = InternalRow(k, v)
+          rows(i) = if (position) InternalRow(i, k, v) else InternalRow(k, v)
           i += 1
         })
         rows
@@ -149,3 +163,35 @@ case class Explode(child: Expression) extends UnaryExpression with Generator wit
     }
   }
 }
+
+/**
+ * Given an input array produces a sequence of rows for each value in the array.
+ *
+ * {{{
+ *   SELECT explode(array(10,20)) ->
+ *   10
+ *   20
+ * }}}
+ */
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+  usage = "_FUNC_(a) - Separates the elements of array a into multiple rows, or the elements of map a into multiple rows and columns.",
+  extended = "> SELECT _FUNC_(array(10,20));\n  10\n  20")
+// scalastyle:on line.size.limit
+case class Explode(child: Expression) extends ExplodeBase(child, position = false)
+
+/**
+ * Given an input array produces a sequence of rows for each position and value in the array.
+ *
+ * {{{
+ *   SELECT posexplode(array(10,20)) ->
+ *   0  10
+ *   1  20
+ * }}}
+ */
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+  usage = "_FUNC_(a) - Separates the elements of array a into multiple rows with positions, or the elements of a map into multiple rows and columns with positions.",
+  extended = "> SELECT _FUNC_(array(10,20));\n  0\t10\n  1\t20")
+// scalastyle:on line.size.limit
+case class PosExplode(child: Expression) extends ExplodeBase(child, position = true)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
index 54436ea9a4..76e42d9afa 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
@@ -166,6 +166,8 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
     assertError(new Murmur3Hash(Nil), "function hash requires at least one argument")
     assertError(Explode('intField),
       "input to function explode should be array or map type")
+    assertError(PosExplode('intField),
+      "input to function explode should be array or map type")
   }
 
   test("check types for CreateNamedStruct") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratorExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratorExpressionSuite.scala
new file mode 100644
index 0000000000..2aba84141b
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratorExpressionSuite.scala
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.unsafe.types.UTF8String
+
+class GeneratorExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
+  private def checkTuple(actual: ExplodeBase, expected: Seq[InternalRow]): Unit = {
+    assert(actual.eval(null).toSeq === expected)
+  }
+
+  private final val int_array = Seq(1, 2, 3)
+  private final val str_array = Seq("a", "b", "c")
+
+  test("explode") {
+    val int_correct_answer = Seq(Seq(1), Seq(2), Seq(3))
+    val str_correct_answer = Seq(
+      Seq(UTF8String.fromString("a")),
+      Seq(UTF8String.fromString("b")),
+      Seq(UTF8String.fromString("c")))
+
+    checkTuple(
+      Explode(CreateArray(Seq.empty)),
+      Seq.empty)
+
+    checkTuple(
+      Explode(CreateArray(int_array.map(Literal(_)))),
+      int_correct_answer.map(InternalRow.fromSeq(_)))
+
+    checkTuple(
+      Explode(CreateArray(str_array.map(Literal(_)))),
+      str_correct_answer.map(InternalRow.fromSeq(_)))
+  }
+
+  test("posexplode") {
+    val int_correct_answer = Seq(Seq(0, 1), Seq(1, 2), Seq(2, 3))
+    val str_correct_answer = Seq(
+      Seq(0, UTF8String.fromString("a")),
+      Seq(1, UTF8String.fromString("b")),
+      Seq(2, UTF8String.fromString("c")))
+
+    checkTuple(
+      PosExplode(CreateArray(Seq.empty)),
+      Seq.empty)
+
+    checkTuple(
+      PosExplode(CreateArray(int_array.map(Literal(_)))),
+      int_correct_answer.map(InternalRow.fromSeq(_)))
+
+    checkTuple(
+      PosExplode(CreateArray(str_array.map(Literal(_)))),
+      str_correct_answer.map(InternalRow.fromSeq(_)))
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 9f35107e5b..a46d1949e9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -159,6 +159,7 @@ class Column(protected[sql] val expr: Expression) extends Logging {
     // Leave an unaliased generator with an empty list of names since the analyzer will generate
     // the correct defaults after the nested expression's type has been resolved.
     case explode: Explode => MultiAlias(explode, Nil)
+    case explode: PosExplode => MultiAlias(explode, Nil)
 
     case jt: JsonTuple => MultiAlias(jt, Nil)
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index e8bd489be3..c8782df146 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -2721,6 +2721,14 @@ object functions {
    */
  def explode(e: Column): Column = withExpr { Explode(e.expr) }
 
+  /**
+   * Creates a new row for each element with position in the given array or map column.
+   *
+   * @group collection_funcs
+   * @since 2.1.0
+   */
+  def posexplode(e: Column): Column = withExpr { PosExplode(e.expr) }
+
   /**
    * Extracts json object from a json string based on json path specified, and returns json string
    * of the extracted json object. It will return null if the input json string is invalid.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index a66c83dea0..a170fae577 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -122,66 +122,6 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext {
     assert(newCol.expr.asInstanceOf[NamedExpression].metadata.getString("key") === "value")
   }
 
-  test("single explode") {
-    val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
-    checkAnswer(
-      df.select(explode('intList)),
-      Row(1) :: Row(2) :: Row(3) :: Nil)
-  }
-
-  test("explode and other columns") {
-    val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
-
-    checkAnswer(
-      df.select($"a", explode('intList)),
-      Row(1, 1) ::
-      Row(1, 2) ::
-      Row(1, 3) :: Nil)
-
-    checkAnswer(
-      df.select($"*", explode('intList)),
-      Row(1, Seq(1, 2, 3), 1) ::
-      Row(1, Seq(1, 2, 3), 2) ::
-      Row(1, Seq(1, 2, 3), 3) :: Nil)
-  }
-
-  test("aliased explode") {
-    val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
-
-    checkAnswer(
-      df.select(explode('intList).as('int)).select('int),
-      Row(1) :: Row(2) :: Row(3) :: Nil)
-
-    checkAnswer(
-      df.select(explode('intList).as('int)).select(sum('int)),
-      Row(6) :: Nil)
-  }
-
-  test("explode on map") {
-    val df = Seq((1, Map("a" -> "b"))).toDF("a", "map")
-
-    checkAnswer(
-      df.select(explode('map)),
-      Row("a", "b"))
-  }
-
-  test("explode on map with aliases") {
-    val df = Seq((1, Map("a" -> "b"))).toDF("a", "map")
-
-    checkAnswer(
-      df.select(explode('map).as("key1" :: "value1" :: Nil)).select("key1", "value1"),
-      Row("a", "b"))
-  }
-
-  test("self join explode") {
-    val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
-    val exploded = df.select(explode('intList).as('i))
-
-    checkAnswer(
-      exploded.join(exploded, exploded("i") === exploded("i")).agg(count("*")),
-      Row(3) :: Nil)
-  }
-
   test("collect on column produced by a binary operator") {
     val df = Seq((1, 2, 3)).toDF("a", "b", "c")
     checkAnswer(df.select(df("a") + df("b")), Seq(Row(3)))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala
new file mode 100644
index 0000000000..1f0ef34ec1
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.test.SharedSQLContext
+
+class GeneratorFunctionSuite extends QueryTest with SharedSQLContext {
+  import testImplicits._
+
+  test("single explode") {
+    val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
+    checkAnswer(
+      df.select(explode('intList)),
+      Row(1) :: Row(2) :: Row(3) :: Nil)
+  }
+
+  test("single posexplode") {
+    val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
+    checkAnswer(
+      df.select(posexplode('intList)),
+      Row(0, 1) :: Row(1, 2) :: Row(2, 3) :: Nil)
+  }
+
+  test("explode and other columns") {
+    val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
+
+    checkAnswer(
+      df.select($"a", explode('intList)),
+      Row(1, 1) ::
+      Row(1, 2) ::
+      Row(1, 3) :: Nil)
+
+    checkAnswer(
+      df.select($"*", explode('intList)),
+      Row(1, Seq(1, 2, 3), 1) ::
+      Row(1, Seq(1, 2, 3), 2) ::
+      Row(1, Seq(1, 2, 3), 3) :: Nil)
+  }
+
+  test("aliased explode") {
+    val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
+
+    checkAnswer(
+      df.select(explode('intList).as('int)).select('int),
+      Row(1) :: Row(2) :: Row(3) :: Nil)
+
+    checkAnswer(
+      df.select(explode('intList).as('int)).select(sum('int)),
+      Row(6) :: Nil)
+  }
+
+  test("explode on map") {
+    val df = Seq((1, Map("a" -> "b"))).toDF("a", "map")
+
+    checkAnswer(
+      df.select(explode('map)),
+      Row("a", "b"))
+  }
+
+  test("explode on map with aliases") {
+    val df = Seq((1, Map("a" -> "b"))).toDF("a", "map")
+
+    checkAnswer(
+      df.select(explode('map).as("key1" :: "value1" :: Nil)).select("key1", "value1"),
+      Row("a", "b"))
+  }
+
+  test("self join explode") {
+    val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
+    val exploded = df.select(explode('intList).as('i))
+
+    checkAnswer(
+      exploded.join(exploded, exploded("i") === exploded("i")).agg(count("*")),
+      Row(3) :: Nil)
+  }
+}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
index fa560a044b..195591fd9d 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
@@ -245,6 +245,6 @@ private[sql] class HiveSessionCatalog(
     "xpath_number", "xpath_short", "xpath_string",
 
     // table generating function
-    "inline", "posexplode"
+    "inline"
   )
 }
-- 
cgit v1.2.3
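For readers who want to try the new function end to end, here is a minimal usage sketch against the Scala DataFrame API and SQL interface added by this patch. The `SparkSession` setup, object name, and sample data are illustrative assumptions, not part of the patch; only `posexplode` itself and its output schema come from the change above.

```scala
// Minimal sketch, assuming Spark 2.1.0+ (the version this patch targets)
// is on the classpath; data and names are hypothetical.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.posexplode

object PosExplodeExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("posexplode-example")
      .master("local[*]")
      .getOrCreate()
    import spark.implicits._

    val df = Seq((1, Seq("a", "b", "c"))).toDF("id", "letters")

    // An array column generates (pos, col) rows, per ExplodeBase.elementSchema.
    df.select($"id", posexplode($"letters")).show()
    // +---+---+---+
    // | id|pos|col|
    // +---+---+---+
    // |  1|  0|  a|
    // |  1|  1|  b|
    // |  1|  2|  c|
    // +---+---+---+

    // A map generates (pos, key, value) rows, matching Hive's posexplode.
    spark.sql("SELECT posexplode(map('a', 1, 'b', 2))").show()

    // The MultiAlias change in Column.scala also lets callers rename the
    // generated columns explicitly.
    df.select(posexplode($"letters").as("p" :: "v" :: Nil)).show()

    spark.stop()
  }
}
```

The aliased call mirrors the existing `explode('map).as("key1" :: "value1" :: Nil)` pattern exercised in GeneratorFunctionSuite.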